繁体   English   中英

如何调用“tfa.losses.triplet_semihard_loss”?

[英]How does "tfa.losses.triplet_semihard_loss" get called?

在 Tensorflow 插件中,有两次提到三元组损失,一个是基类tfa.losses.triplet_semihard_loss ,另一个是tfa.losses.TripletSemiHardLoss ,它是由用户初始化的子类,进而隐式调用基类。 在属于子类的这段代码中:

    def __init__(self, margin=1.0, name=None):
        super(TripletSemiHardLoss, self).__init__(
            name=name, reduction=tf.keras.losses.Reduction.NONE)
        self.margin = margin

    def call(self, y_true, y_pred):
        return triplet_semihard_loss(y_true, y_pred, self.margin)

我不明白call方法发生了什么,它返回基类函数,给出y_truey_pred ndarrays,但它们究竟来自哪里? 根据 Tensorflow 文档指南,子类在模型compile语句中初始化为:

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss())

然后模型拟合为:

history = model.fit(
    train_dataset,
    epochs=5)

train_dataset结构是一个元组,包含嵌入数据和相应的整数标签,但是子类如何意识到这是要操作的数据? 那么call方法也是隐式调用的吗?

__call__在类的实例被调用时被调用。 y_truey_pred包含模型预测的真实标签和标签。 Tensorflow(tf.keras) 在内部将您提供的标签转换为y_true ,如此处所示并使用model.fit()对数据进行model.fit()
所有tf.keras损失以这种形式来实现,即函数有两个参数y_truey_pred所看到这里

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM