[英]How to add example based parameter to custom keras loss function?
I want to have custom loss function in keras, which has a parameter that is different for each training example.我想在 keras 中使用自定义损失函数,它的参数对于每个训练示例都不同。
from keras import backend as K
def my_mse_loss_b(b):
def mseb(y_true, y_pred):
return K.mean(K.square(y_pred - y_true)) + b
return mseb
I read here that y_true and y_pred are always passed to the loss function so you need to create wrapper function.我在这里读到 y_true 和 y_pred 总是传递给损失函数,因此您需要创建包装函数。
model.compile(loss=my_mse_loss_b(df.iloc[:,2]), optimizer='adam', metrics=['accuracy'])
The problem is that when I fit the model there is an error as the function assumes the passed parameters will be as long as the batch.问题是当我拟合模型时会出现错误,因为函数假定传递的参数与批次一样长。 I on the other hand want each example to have there own parameter.
另一方面,我希望每个示例都有自己的参数。
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [20] vs. [10000]
[[node gradients/loss_2/dense_3_loss/mseb/weighted_loss/mul_grad/BroadcastGradientArgs (defined at C:\Users\flis1\Miniconda3\envs\Automate\lib\site-packages\tensorflow_core\python\framework\ops.py:1751) ]] [Op:__inference_keras_scratch_graph_1129]
Function call stack:
keras_scratch_graph
Incompatible shapes it says.它说不兼容的形状。 20 is the batch size and 10000 is the size of my train dataset and the size of all the parameters.
20 是批量大小,10000 是我的训练数据集的大小和所有参数的大小。
I can fit the model if I the parameter I add is the size of the batch, but as I said I want the parameter to be passed on an example basis.如果我添加的参数是批次的大小,我可以拟合模型,但正如我所说,我希望参数以示例为基础传递。
In your case, because your parameter b
is tightly coupled to its training example, it would make sense to make it part of the ground truth.在您的情况下,由于您的参数
b
与其训练示例紧密耦合,因此将其作为基本事实的一部分是有意义的。 You could rewrite your loss function like the following:你可以像下面这样重写你的损失函数:
def mseb(y_true, y_pred):
y_t, b = y_true[0], y_true[1]
return K.mean(K.square(y_pred - y_t)) + b
and then train your model with然后训练你的模型
model.compile(loss=mseb)
b = df.iloc[:,2]
model.fit(X,(y,b))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.