簡體   English   中英

將元數據添加到 tensorflow 凍結圖 pb

[英]Add metadata to tensorflow frozen graph pb

為了分享我們經過訓練的 tensorflow 網絡,我們將圖形凍結到一個.pb文件中。 我們還創建了一個 xml 文件,其中包含一些元數據,例如輸入張量和輸出張量、要應用的預處理類型、訓練數據信息等。然后通過加載圖形和評估張量等使用 Java 或 C# 提供模型。

為了使共享更容易,我想將此 xml 數據包含在.pb文件中的某處。 有沒有辦法做到這一點? 一個想法是將它作為 tf.Constant,但我不知道如何將它連接到普通圖。

請注意,這是使用freeze_graph.py 新的 SavedModel 格式是否更合適?

首先,是的,您應該使用新的 SavedModel 格式,因為它是未來 TF 團隊支持的格式,並且也適用於 Keras。 您可以向模型添加一個額外的端點,它返回一個帶有 XML 數據字符串的常量張量(如您所述)。

這很好,因為它是密封的——底層的 savemodel 格式並不重要,因為您的元數據保存在計算圖本身中。

請參閱此問題的答案: 使用自定義簽名 defs 保存 TF2 keras 模型 對於 Keras,該答案並不能 100% 為您提供幫助,因為它無法與 tf.keras.models.load 函數很好地互操作,因為它們將其包裝在tf.Module 幸運的是,如果你添加一個 tf.function 裝飾器,在 TF2 中使用tf.keras.Model正常工作:

class MyModel(tf.keras.Model):

  def __init__(self, metadata, **kwargs):
    super(MyModel, self).__init__(**kwargs)
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.metadata = tf.constant(metadata)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

  @tf.function(input_signature=[])
  def get_metadata(self):
    return self.metadata

model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)

然后您可以按如下方式保存和加載您的模型:

tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')

最后使用model_loaded.get_metadata()來檢索您的常量元數據張量。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM