![](/img/trans.png)
[英]Tensorflow: Load a .pb file and then save it as a frozen graph issues
[英]Add metadata to tensorflow frozen graph pb
為了分享我們經過訓練的 tensorflow 網絡,我們將圖形凍結到一個.pb
文件中。 我們還創建了一個 xml 文件,其中包含一些元數據,例如輸入張量和輸出張量、要應用的預處理類型、訓練數據信息等。然后通過加載圖形和評估張量等使用 Java 或 C# 提供模型。
為了使共享更容易,我想將此 xml 數據包含在.pb
文件中的某處。 有沒有辦法做到這一點? 一個想法是將它作為 tf.Constant,但我不知道如何將它連接到普通圖。
請注意,這是使用freeze_graph.py
。 新的 SavedModel 格式是否更合適?
首先,是的,您應該使用新的 SavedModel 格式,因為它是未來 TF 團隊支持的格式,並且也適用於 Keras。 您可以向模型添加一個額外的端點,它返回一個帶有 XML 數據字符串的常量張量(如您所述)。
這很好,因為它是密封的——底層的 savemodel 格式並不重要,因為您的元數據保存在計算圖本身中。
請參閱此問題的答案: 使用自定義簽名 defs 保存 TF2 keras 模型。 對於 Keras,該答案並不能 100% 為您提供幫助,因為它無法與 tf.keras.models.load 函數很好地互操作,因為它們將其包裝在tf.Module
。 幸運的是,如果你添加一個 tf.function 裝飾器,在 TF2 中使用tf.keras.Model
正常工作:
class MyModel(tf.keras.Model):
def __init__(self, metadata, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.metadata = tf.constant(metadata)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
@tf.function(input_signature=[])
def get_metadata(self):
return self.metadata
model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)
然后您可以按如下方式保存和加載您的模型:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
最后使用model_loaded.get_metadata()
來檢索您的常量元數據張量。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.