![](/img/trans.png)
[英]How tensorflow graph regularization (NSL) affects triplet semihard loss (TFA)
[英]How does "tfa.losses.triplet_semihard_loss" get called?
在 Tensorflow 插件中,有两次提到三元组损失,一个是基类tfa.losses.triplet_semihard_loss
,另一个是tfa.losses.TripletSemiHardLoss
,它是由用户初始化的子类,进而隐式调用基类。 在属于子类的这段代码中:
def __init__(self, margin=1.0, name=None):
super(TripletSemiHardLoss, self).__init__(
name=name, reduction=tf.keras.losses.Reduction.NONE)
self.margin = margin
def call(self, y_true, y_pred):
return triplet_semihard_loss(y_true, y_pred, self.margin)
我不明白call
方法发生了什么,它返回基类函数,给出y_true
和y_pred
ndarrays,但它们究竟来自哪里? 根据 Tensorflow 文档指南,子类在模型compile
语句中初始化为:
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tfa.losses.TripletSemiHardLoss())
然后模型拟合为:
history = model.fit(
train_dataset,
epochs=5)
train_dataset
结构是一个元组,包含嵌入数据和相应的整数标签,但是子类如何意识到这是要操作的数据? 那么call
方法也是隐式调用的吗?
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.