[英]Keras load_model with custom objects doesn't work properly
Setting环境
As already mentioned in the title, I got a problem with my custom loss function, when trying to load the saved model.正如标题中已经提到的,在尝试加载保存的模型时,我的自定义损失函数出现了问题。 My loss looks as follows:
我的损失如下:
def weighted_cross_entropy(weights):
weights = K.variable(weights)
def loss(y_true, y_pred):
y_pred = K.clip(y_pred, K.epsilon(), 1-K.epsilon())
loss = y_true * K.log(y_pred) * weights
loss = -K.sum(loss, -1)
return loss
return loss
weighted_loss = weighted_cross_entropy([0.1,0.9])
So during training, I used the weighted_loss
function as loss function and everything worked well.所以在训练期间,我使用了
weighted_loss
函数作为损失函数,一切运行良好。 When training is finished I save the model as .h5
file with the standard model.save
function from keras API.训练完成后,我使用 keras API 中的标准
model.save
函数将模型保存为.h5
文件。
Problem问题
When I am trying to load the model via当我试图通过加载模型时
model = load_model(path,custom_objects={"weighted_loss":weighted_loss})
I am getting a ValueError
telling me that the loss is unknown.我收到一个
ValueError
,告诉我损失未知。
Error错误
The error message looks as follows:错误消息如下所示:
File "...\predict.py", line 29, in my_script
"weighted_loss": weighted_loss})
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
model = _deserialize_model(f, custom_objects, compile)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
sample_weight_mode=sample_weight_mode)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
loss_function = losses.get(loss)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
return deserialize(identifier)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
printable_module_name='loss function')
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
':' + function_name)
ValueError: Unknown loss function:loss
Questions问题
How can I fix this problem?我该如何解决这个问题? May it be possible that the reason for that is my wrapped loss definition?
这可能是我的包装损失定义的原因吗? So
keras
doesn't know, how to handle the weights
variable?所以
keras
不知道,如何处理weights
变量?
Your loss function's name is loss
(ie def loss(y_true, y_pred):
).您的损失函数的名称是
loss
(即def loss(y_true, y_pred):
)。 Therefore, when loading back the model you need to specify 'loss'
as its name:因此,在加载模型时,您需要指定
'loss'
作为其名称:
model = load_model(path, custom_objects={'loss': weighted_loss})
For full examples demonstrating saving and loading Keras models with custom loss functions or models , please have a look at the following GitHub gist files:有关演示使用自定义损失函数或模型保存和加载 Keras 模型的完整示例,请查看以下 GitHub gist 文件:
Custom loss function defined using a wrapper: https://gist.github.com/ashkan-abbasi66/a81fe4c4d588e2c187180d5bae734fde使用包装器定义的自定义损失函数: https ://gist.github.com/ashkan-abbasi66/a81fe4c4d588e2c187180d5bae734fde
Custom loss function defined by subclassing: https://gist.github.com/ashkan-abbasi66/327efe2dffcf9788847d26de934ef7bd通过子类化定义的自定义损失函数: https ://gist.github.com/ashkan-abbasi66/327efe2dffcf9788847d26de934ef7bd
Custom model: https://gist.github.com/ashkan-abbasi66/d5a525d33600b220fa7b095f7762cb5b自定义模型: https ://gist.github.com/ashkan-abbasi66/d5a525d33600b220fa7b095f7762cb5b
Note: I tested the above examples on Python 3.8 with Tensorflow 2.5.注意:我在 Python 3.8 上使用 Tensorflow 2.5 测试了上述示例。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.