[英]How to access sample weights in a Keras custom loss function supplied by a generator?
我有一个生成器函数,它在某些图像目录上无限循环并输出批量的 3 元组表单
[img1, img2], label, weight
其中img1
和img2
是batch_size x M x N x 3
张量, label
和weight
分别是batch_size
x 1 张量。
在使用fit_generator
训练模型时,我将此生成器提供给fit_generator
函数。
对于这个模型,我有一个自定义的余弦对比损失函数,
def cosine_constrastive_loss(y_true, y_pred):
cosine_distance = 1 - y_pred
margin = 0.9
cdist = y_true * y_pred + (1 - y_true) * keras.backend.maximum(margin - y_pred, 0.0)
return keras.backend.mean(cdist)
从结构上讲,我的模型一切正常。 没有错误,它正在按预期消耗来自生成器的输入和标签。
但是现在我正在寻求直接使用每个批次的权重参数,并根据特定于样本的权cosine_contrastive_loss
内部执行一些自定义逻辑。
如何在执行损失函数时从一批样本的结构中访问此参数?
请注意,由于它是一个无限循环的生成器,因此无法预先计算权重或动态计算它们以将权重归入损失函数或生成它们。
它们必须与正在生成的样本一致生成,并且确实在我的数据生成器中有自定义逻辑,可以在为批处理生成时根据img1
、 img2
和label
属性动态确定权重。
我唯一能想到的是手动训练循环,您可以自己获得重量。
有一个权重张量和一个不可变的批量大小:
weights = K.variable(np.zeros((batch_size,)))
在您的自定义损失中使用它们:
def custom_loss(true, pred):
return someCalculation(true, pred, weights)
对于“生成器”:
for e in range(epochs):
for s in range(steps_per_epoch):
x, y, w = next(generator) #or generator.next(), not sure
K.set_value(weights, w)
model.train_on_batch(x, y)
对于keras.utils.Sequence
:
for e in range(epochs):
for s in range(len(generator)):
x,y,w = generator[s]
K.set_value(weights, w)
model.train_on_batch(x,y)
我知道这个答案不是最优的,因为它不会像fit_generator
那样并行从生成器获取数据。 但这是我能想到的最好的简单解决方案。 Keras 没有公开权重,它们会自动应用在一些隐藏的源代码中。
如果可以从x
和y
计算权重,您可以将此任务委托给损失函数本身。
这有点hacky,但可能有效:
input1 = Input(shape1)
input2 = Input(shape2)
# .... model creation .... #
model = Model([input1, input2], outputs)
让损失可以访问input1
和input2
:
def custom_loss(y_true, y_pred):
w = calculate_weights(input1, input2, y_pred)
# .... rest of the loss .... #
这里的问题是您是否可以根据输入将权重计算为张量。
使用样本权重调用 Keras Tensorflow v2 中的损失函数
output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
https://github.com/keras-team/keras/blob/tf-2/keras/engine/training.py
您可以使用 GradientTape 进行自定义训练,请参阅https://www.tensorflow.org/guide/keras/train_and_evaluate#part_ii_writing_your_own_training_evaluation_loops_from_scratch
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.