[英]Keras: passing the model object as an argument to the loss function
This post almost does what I want.这篇文章几乎完成了我想要的。 In a nutshell, the suggested solution is:简而言之,建议的解决方案是:
def custom_loss(y_true, y_pred):
# Your model exists in global scope global e
# Get the layers of your model
layers = [l for l in e.layers]
# Construct a graph to evaluate your other model on y_pred
eval_pred = y_pred
for i in range(len(layers)):
eval_pred = layers[i](eval_pred)
# Construct a graph to evaluate your other model on y_true
eval_true = y_true
for i in range(len(layers)):
eval_true = layers[i](eval_true)
# Now do what you wanted to do with outputs.
# Note that we are not returning the values, but a tensor.
return K.mean(K.square(eval_pred - eval_true), axis=-1)
In the function above, e
is a global argument, which is the model itself, and the custom loss function uses the model (which is global) without requiring the user to pass in the model.在上面的函数中, e
是一个全局参数,即模型本身,自定义损失函数使用模型(即全局)而不需要用户传入模型。 I'm not a big fan of global arguments.我不是全球争论的忠实粉丝。 Is there a way to construct a custom_loss function such that it takes in the model object itself without using a global argument.有没有办法构造一个 custom_loss 函数,使其在不使用全局参数的情况下接收模型对象本身。 For example, can I create a function custom_loss(y_true, y_pred, e)
and delete the line global e
, such that I can pass my custom_loss
as a loss function of a model?例如,我可以创建一个函数custom_loss(y_true, y_pred, e)
并删除global e
行,这样我就可以将custom_loss
作为模型的损失函数传递吗?
Keras API does not support that. Keras API 不支持。 As the documentation states, loss functions take exactly two arguments: y_true
and y_pred
.正如文档所述,损失函数采用两个参数: y_true
和y_pred
。
If you what such a feature, you have to modify Keras itself.如果你有什么这样的功能,你就得修改 Keras 本身。 Take a look at:看一眼:
compile
function in keras/engine/training.py
keras/engine/training.py
的compile
函数weighted_masked_objective
function in keras/engine/training_utils.py
keras/engine/training_utils.py
的weighted_masked_objective
函数
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.