![](/img/trans.png)
[英]Tensorflow: Load a .pb file and then save it as a frozen graph issues
[英]Add metadata to tensorflow frozen graph pb
为了分享我们经过训练的 tensorflow 网络,我们将图形冻结到一个.pb
文件中。 我们还创建了一个 xml 文件,其中包含一些元数据,例如输入张量和输出张量、要应用的预处理类型、训练数据信息等。然后通过加载图形和评估张量等使用 Java 或 C# 提供模型。
为了使共享更容易,我想将此 xml 数据包含在.pb
文件中的某处。 有没有办法做到这一点? 一个想法是将它作为 tf.Constant,但我不知道如何将它连接到普通图。
请注意,这是使用freeze_graph.py
。 新的 SavedModel 格式是否更合适?
首先,是的,您应该使用新的 SavedModel 格式,因为它是未来 TF 团队支持的格式,并且也适用于 Keras。 您可以向模型添加一个额外的端点,它返回一个带有 XML 数据字符串的常量张量(如您所述)。
这很好,因为它是密封的——底层的 savemodel 格式并不重要,因为您的元数据保存在计算图本身中。
请参阅此问题的答案: 使用自定义签名 defs 保存 TF2 keras 模型。 对于 Keras,该答案并不能 100% 为您提供帮助,因为它无法与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在tf.Module
。 幸运的是,如果你添加一个 tf.function 装饰器,在 TF2 中使用tf.keras.Model
正常工作:
class MyModel(tf.keras.Model):
def __init__(self, metadata, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.metadata = tf.constant(metadata)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
@tf.function(input_signature=[])
def get_metadata(self):
return self.metadata
model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)
然后您可以按如下方式保存和加载您的模型:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
最后使用model_loaded.get_metadata()
来检索您的常量元数据张量。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.