简体   繁体   English

使用自定义层保存 Tensorflow 模型

[英]Saving Tensorflow models with custom layers

I read through the documentation, but something wasn't clear for me: if I coded a custom layer and then used it in a model, can I just save the model as SavedModel and the custom layer automatically goes within it or do I have to save the custom layer too?我通读了文档,但对我来说有些不清楚:如果我编写了一个自定义层然后在模型中使用它,我可以将模型保存为 SavedModel 并且自定义层自动进入其中还是我必须这样做也保存自定义图层? I tried saving just the model in H5 format and not the custom layer.我尝试仅以 H5 格式保存模型而不是自定义图层。 When I tried to load the model, I had an error on the custom layer not being recognized or something like this.当我尝试加载模型时,我在自定义图层上出现错误,无法识别或类似的错误。 Reading through the documentation, I saw that saving to custom objects to H5 format is a bit more involved.通读文档,我发现将自定义对象保存为 H5 格式要复杂一些。 But how does it work with SavedModels?但是它如何与 SavedModels 一起工作?

If I understand your question, you should simply use tf.keras.models.save_model(<model_object>,'file_name',save_format='tf') .如果我理解你的问题,你应该简单地使用tf.keras.models.save_model(<model_object>,'file_name',save_format='tf')

My understanding is that the 'tf' format automatically saves the custom layers, so loading doesn't require all libraries be present.我的理解是“tf”格式会自动保存自定义图层,因此加载不需要所有库都存在。 This doesn't extend to all custom objects, but I don't know where that distinction lies.这不会扩展到所有自定义对象,但我不知道区别在哪里。 If you want to load a model that uses non-layer custom objects you have to use the custom_objects parameter in tf.keras.models.load_model() .如果要加载使用非层自定义对象的模型,则必须使用tf.keras.models.load_model()custom_objects参数。 This is only necessary if you want to train immediately after loading.仅当您想在加载后立即训练时才需要这样做。 If you don't intend to train the model immediately, you should be able to forego custom_objects and just set compile=False in load_model .如果您不打算立即训练模型,您应该可以放弃custom_objects并在load_model设置compile=False

If you want to use the 'h5' format, you supposedly have to have all libraries/modules/packages that the custom object utilizes present and loaded in order for the 'h5' load to work.如果你想使用“h5”格式,你应该让自定义对象使用的所有库/模块/包都存在并加载,以便“h5”加载工作。 I know I've done this with an intializer before.我知道我以前用初始化器做过这件事。 This might not matter for layers, but I assume that it does.这对于图层可能无关紧要,但我认为它确实如此。

You also need to implement get_config() and save_config() functions in the custom object definition in order for 'h5' to save and load properly.您还需要在自定义对象定义中实现get_config()save_config()函数,以便“h5”正确保存和加载。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM