繁体   English   中英

将元数据添加到 tensorflow 冻结图 pb

[英]Add metadata to tensorflow frozen graph pb

为了分享我们经过训练的 tensorflow 网络,我们将图形冻结到一个.pb文件中。 我们还创建了一个 xml 文件,其中包含一些元数据,例如输入张量和输出张量、要应用的预处理类型、训练数据信息等。然后通过加载图形和评估张量等使用 Java 或 C# 提供模型。

为了使共享更容易,我想将此 xml 数据包含在.pb文件中的某处。 有没有办法做到这一点? 一个想法是将它作为 tf.Constant,但我不知道如何将它连接到普通图。

请注意,这是使用freeze_graph.py 新的 SavedModel 格式是否更合适?

首先,是的,您应该使用新的 SavedModel 格式,因为它是未来 TF 团队支持的格式,并且也适用于 Keras。 您可以向模型添加一个额外的端点,它返回一个带有 XML 数据字符串的常量张量(如您所述)。

这很好,因为它是密封的——底层的 savemodel 格式并不重要,因为您的元数据保存在计算图本身中。

请参阅此问题的答案: 使用自定义签名 defs 保存 TF2 keras 模型 对于 Keras,该答案并不能 100% 为您提供帮助,因为它无法与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在tf.Module 幸运的是,如果你添加一个 tf.function 装饰器,在 TF2 中使用tf.keras.Model正常工作:

class MyModel(tf.keras.Model):

  def __init__(self, metadata, **kwargs):
    super(MyModel, self).__init__(**kwargs)
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.metadata = tf.constant(metadata)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

  @tf.function(input_signature=[])
  def get_metadata(self):
    return self.metadata

model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)

然后您可以按如下方式保存和加载您的模型:

tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')

最后使用model_loaded.get_metadata()来检索您的常量元数据张量。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM