简体   繁体   English

如何从 .h5 文件正确加载带有自定义层的 Keras 模型?

[英]How to load the Keras model with custom layers from .h5 file correctly?

I built a Keras model with a custom layers, and it was saved to a .h5 file by the callback ModelCheckPoint .我构建了一个带有自定义层的 Keras 模型,并通过回调ModelCheckPoint将其保存到.h5文件中。 When I tried to load this model after the training, the error message below showed up:当我在训练后尝试加载此模型时,出现以下错误消息:

 __init__() missing 1 required positional argument: 'pool_size'

This is the definition of the custom layer and its __init__ method:这是自定义层及其__init__方法的定义:

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = pool_size
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

This is how I add this layer to my model:这就是我将此层添加到我的模型的方式:

x = MyMeanPooling(globalvars.pool_size)(x)

This is how I load the model:这是我加载模型的方式:

from keras.models import load_model

model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})

These are the full error messages:这些是完整的错误消息:

Traceback (most recent call last):
  File "D:/My Projects/Attention_BLSTM/script3.py", line 9, in <module>
    model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1022, in from_config
    process_layer(layer_data)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1109, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool_size'

Actually I don't think you can load this model.其实我不认为你可以加载这个模型。

The most likely issue is that you did not implement the get_config() method in your layer.最可能的问题是您没有在层中实现get_config()方法。 This method returns a dictionary of configuration values that should be saved:此方法返回应保存的配置值字典:

def get_config(self):
    config = {'pool_size': self.pool_size,
              'axis': self.axis}
    base_config = super(MyMeanPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

You have to retrain the model after adding this method to your layer, as the previously saved model does not have the configuration for this layer saved into it.将此方法添加到您的层后,您必须重新训练模型,因为之前保存的模型没有保存该层的配置。 This is why you cannot load it, it requires retraining after making this change.这就是您无法加载它的原因,它需要在进行此更改后重新训练。

From the answer of "LiamHe commented on Sep 27, 2017" on the following issue: https://github.com/keras-team/keras/issues/4871 .来自“LiamHe 在 2017 年 9 月 27 日发表评论”对以下问题的回答: https : //github.com/keras-team/keras/issues/4871

I met the same problem today : ** TypeError: init() missing 1 required positional arguments**.我今天遇到了同样的问题:** TypeError: init() missing 1 required positional arguments**。 Here is how I solve the problem : (Keras 2.0.2)这是我解决问题的方法:(Keras 2.0.2)

  1. Give the positional arguments of the layer with some default values使用一些默认值给出图层的位置参数
  2. Override get_config function to the layer with some thing like使用类似的东西将 get_config 函数覆盖到层
def get_config(self):
    config = super().get_config()
    config['pool_size'] = # say self._pool_size  if you store the argument in __init__
    return config
  1. Add layer class to custom_objects when you are loading model.加载模型时将图层类添加到 custom_objects。

If you don't have enough time to retrain the model in the solution way of Matias Valdenegro.如果您没有足够的时间以 Matias Valdenegro 的求解方式重新训练模型。 You can set the default value of pool_size in class MyMeanPooling like the following code.您可以设置pool_size的类MyMeanPooling类似下面的代码的默认值。 Note that the value of pool_size should be consistent with the value while training the model.注意pool_size的值要与训练模型时的值保持一致。 Then you can load the model.然后就可以加载模型了。

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = 2  # The value should be consistent with the value while training the model
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

ref: https://www.jianshu.com/p/e97112c34e43参考: https : //www.jianshu.com/p/e97112c34e43

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM