How to modify a variable inside the loss function in each epoch during training?
I have a custom loss function. In each epoch I would like to either keep or throw away my input matrix randomly:
import random

import tensorflow as tf
from tensorflow.keras import backend

def decision(probability):
    return random.random() < probability

def my_throw_loss_in1(y_true, y_pred):
    # in1 is the input matrix, defined elsewhere in the model
    if decision(probability=0.5):
        keep_mask = tf.ones_like(in1)
        total_loss = backend.mean(backend.square(y_true - y_pred)) * keep_mask
        print('Input1 is kept')
    else:
        throw_mask = tf.zeros_like(in1)
        total_loss = backend.mean(backend.square(y_true - y_pred)) * throw_mask
        print('Input1 is thrown away')
    return total_loss
model.compile(loss=[my_throw_loss_in1],
              optimizer='Adam',
              metrics=['mae'])

history2 = model.fit([x, y], batch_size=10, epochs=150, validation_split=0.2, shuffle=True)
but this would only set the decision value once, and the loss is not re-evaluated with a fresh decision in each epoch. How do I write a loss function whose variable can be modified in each epoch?
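As a side note on why the decision is set only once: Keras traces the loss function into a graph, so plain Python code such as random.random() runs a single time during tracing and its result is baked into the graph. A minimal sketch of this behavior (independent of the model above, assuming TensorFlow 2.x):

```python
import random

import tensorflow as tf

@tf.function
def frozen_loss(y_true, y_pred):
    # random.random() is plain Python: it runs only while TensorFlow
    # traces this function into a graph, so the chosen branch is baked
    # in and is never re-drawn on later calls.
    if random.random() < 0.5:
        return tf.reduce_mean(tf.square(y_true - y_pred))
    return tf.constant(0.0)

y = tf.ones([4])
p = tf.zeros([4])
first = float(frozen_loss(y, p))  # either 1.0 or 0.0, fixed at trace time
```

Every later call with the same input shapes reuses the traced graph, so the branch never changes.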
Here are some thoughts:

Just change your loss function as follows in order for it to be evaluated each time it is called during fit(*):
import tensorflow as tf

def my_throw_loss_in1(y_true, y_pred):
    probability = 0.5
    # Draw a fresh random scalar as a graph op, so it re-executes on every call
    random_uniform = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
    condition = tf.less(random_uniform, probability)
    # Keep the loss (mask of ones) or throw it away (mask of zeros)
    mask = tf.cond(condition, lambda: tf.ones_like(y_true), lambda: tf.zeros_like(y_true))
    total_loss = tf.keras.backend.mean(tf.keras.backend.square(y_true - y_pred) * mask)
    tf.print(mask)
    return total_loss
First, a random number is generated, and then a condition (the random number being less than the probability you defined) is created based on this number. Afterwards, you just use tf.cond to return tf.ones_like if your condition is True, otherwise tf.zeros_like. Finally, the mask is simply applied to your loss.
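The key difference can be checked directly: tf.random.uniform and tf.cond are graph ops, so they re-execute on every call even inside a traced function, unlike Python's random.random(). A minimal sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf

@tf.function
def graph_mask():
    # tf.random.uniform and tf.cond are graph ops: they re-execute on
    # every call, so a fresh mask is drawn for every batch.
    r = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
    return tf.cond(tf.less(r, 0.5), lambda: tf.ones([]), lambda: tf.zeros([]))

# Over many calls, both branches are taken with overwhelming probability.
vals = {float(graph_mask()) for _ in range(50)}
```

This is why the rewritten loss draws a new decision on every batch, rather than freezing it at compile time.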