簡體   English   中英

如何從 .h5 文件正確加載帶有自定義層的 Keras 模型?

[英]How to load the Keras model with custom layers from .h5 file correctly?

我構建了一個帶有自定義層的 Keras 模型,並通過回調ModelCheckPoint將其保存到.h5文件中。 當我在訓練后嘗試加載此模型時,出現以下錯誤消息:

 __init__() missing 1 required positional argument: 'pool_size'

這是自定義層及其__init__方法的定義:

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = pool_size
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

這就是我將此層添加到我的模型的方式:

x = MyMeanPooling(globalvars.pool_size)(x)

這是我加載模型的方式:

from keras.models import load_model

model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})

這些是完整的錯誤消息:

Traceback (most recent call last):
  File "D:/My Projects/Attention_BLSTM/script3.py", line 9, in <module>
    model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1022, in from_config
    process_layer(layer_data)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1109, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool_size'

其實我不認為你可以加載這個模型。

最可能的問題是您沒有在層中實現get_config()方法。 此方法返回應保存的配置值字典:

def get_config(self):
    config = {'pool_size': self.pool_size,
              'axis': self.axis}
    base_config = super(MyMeanPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

將此方法添加到您的層后,您必須重新訓練模型,因為之前保存的模型沒有保存該層的配置。 這就是您無法加載它的原因,它需要在進行此更改后重新訓練。

來自“LiamHe 在 2017 年 9 月 27 日發表評論”對以下問題的回答: https : //github.com/keras-team/keras/issues/4871

我今天遇到了同樣的問題:** TypeError: init() missing 1 required positional arguments**。 這是我解決問題的方法:(Keras 2.0.2)

  1. 使用一些默認值給出圖層的位置參數
  2. 使用類似的東西將 get_config 函數覆蓋到層
def get_config(self):
    config = super().get_config()
    config['pool_size'] = # say self._pool_size  if you store the argument in __init__
    return config
  1. 加載模型時將圖層類添加到 custom_objects。

如果您沒有足夠的時間以 Matias Valdenegro 的求解方式重新訓練模型。 您可以設置pool_size的類MyMeanPooling類似下面的代碼的默認值。 注意pool_size的值要與訓練模型時的值保持一致。 然后就可以加載模型了。

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = 2  # The value should be consistent with the value while training the model
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

參考: https : //www.jianshu.com/p/e97112c34e43

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM