[英]How do you use Tensorflow Keras Custom Objects with tf.saved_model.Asset?
我有一个自定义的 Keras 层,它从泡菜文件中读取以初始化一些权重,我希望能够在其上使用tf.keras.utils.register_keras_serializable()
。 问题是我的__init__
function 采用了 pickle 文件的路径,当层再次反序列化时,该路径可能不可用。 Keras资产理论上应该使层更便携,但我无法弄清楚如何让它与层的get_config()
一起工作。
我的代码的准系统版本:
@tf.keras.utils.register_keras_serializable()
class AssetLayer(tf.keras.layers.Layer):
def __init__(self, asset_path, **kwargs):
super().__init__(**kwargs)
self.asset_path = asset_path
self.asset = tf.saved_model.Asset(asset_path)
data = tf.io.read_file(self.asset)
# do something with data
def get_config(self):
return {
**super().get_config(),
"asset_path": self.asset_path,
}
def call(self, arg):
# arbitrary call function
return arg
如果使用该层的 model 是使用tf.keras.models.load_model()
加载的,Keras 将调用get_config()
以使用保存的asset_path
重新初始化该层,该路径在反序列化时可能未指向正确的位置。 理想情况下,它会指向已保存资产的路径,但我不知道如何让它这样做。
例如,我试过这段代码
!echo abcd > file.txt
model = tf.keras.Sequential([AssetLayer("file.txt")])
model(tf.ones(3))
model.save("test")
# reloading
!rm file.txt
reloaded_model = tf.keras.models.load_model("test")
这给了我一个错误,说找不到file.txt
。
我也试过完全删除get_config()
function。 这使得图层可以成功重新加载,同时保留对asset
变量的访问权限,但无法访问图层中的其他属性,例如self.asset_path
。 这不是调试目的的理想选择,所以我想知道是否有更好的方法。
我目前正在使用 Tensorflow 2.5.0`
编辑代码:在此部分之前,代码很好。问题正在复制,因为
!rm file.txt
(所以我放在最后了)
!echo abcd > file.txt
model = tf.keras.Sequential([AssetLayer("file.txt")])
model(tf.ones(3))
model.save("./content/sample_data/test.h5")
# reloading
reloaded_model = tf.keras.models.load_model("/content/content/sample_data/test.h5")
reloaded_model.summary()
!rm file.txt
参考: https://www.tensorflow.org/guide/keras/save_and_serialize
似乎“tf.saved_model.Asset”不支持“tf.keras.models.load_model”尝试使用 tf.saved_model.save / tf.saved_model.load
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.