繁体   English   中英

由于使用“lambda”,无法加载已保存的 Keras model

[英]Cannot load saved Keras model due to use of “lambda”

我有一个简单的 Keras 网络,它使用定义为 lambda 的自定义激活 function:

from tensorflow.keras.activations import relu
lrelu = lambda x: relu( x, alpha=0.01 )
model = Sequential
model.add(Dense( 10, activation=lrelu, input_dim=12 ))
...

它可以很好地编译、训练、测试(代码省略),我可以使用model.save( 'model.h5' )来保存它。 但是当我尝试使用loaded = tf.keras.models.load_model( 'model.h5', custom_objects={'lrelu': lrelu})加载它时,尽管定义lrelu ,如上所示,它抱怨:

ValueError: Unknown activation function:<lambda>

等一下:lambda 不是lambda关键字吗? 我不打算重新定义 python 所以我可以加载 model - 它会在哪里结束? 我该如何克服呢? 我需要指定什么作为我的custom_objects

根据TF Keras 指南,使用自定义对象和函数进行保存和加载...

自定义函数(例如激活丢失或初始化)不需要 get_config 方法。 只要注册为自定义 object,function 名称就足以加载。

在我看来,这正是我所做的。 难道这仅适用于用def定义的函数而不适用于 lambda 函数吗?

Lambda 没有有效的名称属性 Keras 可以自省,因此在序列化过程中会混淆。 请改用命名的 function。

from tensorflow.keras.activations import relu

def lrelu(x):
   return relu(x, alpha=0.01)

model = Sequential()
model.add(Dense( 10, activation=lrelu, input_dim=12 ))

以机智:

>>> lrelu1 = lambda x: 0
>>> def lrelu2(x):
...   return 0
...
>>> lrelu1.__name__
'<lambda>'
>>> lrelu2.__name__
'lrelu2'
>>>

这是另一种包装激活 function 的方法

model = Sequential()
model.add(Dense( 10, input_dim=12 ))
model.add(Lambda( lambda x: tf.keras.activations.relu( x, alpha=0.01 ) ))

这与执行 model.add(Activation('...')) 的概念相同,但使用自定义修改激活

用于保存和加载:

model.save( 'model.h5' )
loaded = tf.keras.models.load_model( 'model.h5' )

我使用https://colab.research.google.com/drive/1K-4_nt66AH5PQDv9Fn-l69-eu5S6Y5EU?usp=sharing来保存和加载 model 没有问题

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM