简体   繁体   English

带有自定义对象的 Keras load_model 无法正常工作

[英]Keras load_model with custom objects doesn't work properly

Setting环境

As already mentioned in the title, I got a problem with my custom loss function, when trying to load the saved model.正如标题中已经提到的,在尝试加载保存的模型时,我的自定义损失函数出现了问题。 My loss looks as follows:我的损失如下:

def weighted_cross_entropy(weights):

    weights = K.variable(weights)

    def loss(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1-K.epsilon())

        loss = y_true * K.log(y_pred) * weights
        loss = -K.sum(loss, -1)
        return loss

    return loss

weighted_loss = weighted_cross_entropy([0.1,0.9])

So during training, I used the weighted_loss function as loss function and everything worked well.所以在训练期间,我使用了weighted_loss函数作为损失函数,一切运行良好。 When training is finished I save the model as .h5 file with the standard model.save function from keras API.训练完成后,我使用 keras API 中的标准model.save函数将模型保存为.h5文件。

Problem问题

When I am trying to load the model via当我试图通过加载模型时

model = load_model(path,custom_objects={"weighted_loss":weighted_loss})

I am getting a ValueError telling me that the loss is unknown.我收到一个ValueError ,告诉我损失未知。

Error错误

The error message looks as follows:错误消息如下所示:

File "...\predict.py", line 29, in my_script
"weighted_loss": weighted_loss})
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
model = _deserialize_model(f, custom_objects, compile)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
sample_weight_mode=sample_weight_mode)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
loss_function = losses.get(loss)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
return deserialize(identifier)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
printable_module_name='loss function')
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
':' + function_name)
ValueError: Unknown loss function:loss

Questions问题

How can I fix this problem?我该如何解决这个问题? May it be possible that the reason for that is my wrapped loss definition?这可能是我的包装损失定义的原因吗? So keras doesn't know, how to handle the weights variable?所以keras不知道,如何处理weights变量?

Your loss function's name is loss (ie def loss(y_true, y_pred): ).您的损失函数的名称是loss (即def loss(y_true, y_pred): )。 Therefore, when loading back the model you need to specify 'loss' as its name:因此,在加载模型时,您需要指定'loss'作为其名称:

model = load_model(path, custom_objects={'loss': weighted_loss})

For full examples demonstrating saving and loading Keras models with custom loss functions or models , please have a look at the following GitHub gist files:有关演示使用自定义损失函数或模型保存和加载 Keras 模型的完整示例,请查看以下 GitHub gist 文件:

Custom loss function defined using a wrapper: https://gist.github.com/ashkan-abbasi66/a81fe4c4d588e2c187180d5bae734fde使用包装器定义的自定义损失函数: https ://gist.github.com/ashkan-abbasi66/a81fe4c4d588e2c187180d5bae734fde

Custom loss function defined by subclassing: https://gist.github.com/ashkan-abbasi66/327efe2dffcf9788847d26de934ef7bd通过子类化定义的自定义损失函数: https ://gist.github.com/ashkan-abbasi66/327efe2dffcf9788847d26de934ef7bd

Custom model: https://gist.github.com/ashkan-abbasi66/d5a525d33600b220fa7b095f7762cb5b自定义模型: https ://gist.github.com/ashkan-abbasi66/d5a525d33600b220fa7b095f7762cb5b

Note: I tested the above examples on Python 3.8 with Tensorflow 2.5.注意:我在 Python 3.8 上使用 Tensorflow 2.5 测试了上述示例。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM