![](/img/trans.png)
[英]How to convert .pb file to .h5. (Tensorflow model to keras)
[英]How to store model in `.pb` file with Estimator in TensorFlow?
我使用TensorFlow的估算器訓練了模型。 似乎應該使用export_savedmodel
來制作.pb
文件,但是我真的不知道如何構造serving_input_receiver_fn
。 有人有想法嗎? 示例代碼受到歡迎。
額外問題:
.pb
是我要重新加載模型時唯一需要的文件嗎? Variable
不必要嗎?
與使用adam優化程序的.ckpt
相比, .pb
將減少多少模型文件大小?
如果使用的是tf.estimator.Estimator
,則可以使用freeze_graph.py
從.ckpt
+ .pbtxt
生成.pb
,然后在model_dir
找到這兩個文件
python freeze_graph.py \
--input_graph=graph.pbtxt \
--input_checkpoint=model.ckpt-308 \
--output_graph=output_graph.pb
--output_node_names=<output_node>
- .pb是我要重新加載模型時唯一需要的文件嗎? 變量不必要嗎?
是的 ,您還必須知道您是模型的輸入節點和輸出節點名稱。 然后使用import_graph_def
加載.pb文件,並使用get_operation_by_name
獲取輸入和輸出操作
- 與使用adam優化程序的.ckpt相比,.pb將減少多少模型文件大小?
.pb文件不是壓縮的.ckpt文件,因此沒有“壓縮率”。
但是,有一種方法可以優化.pb文件以進行推理,並且此優化可能會減小文件大小,因為它會刪除僅訓練操作的圖形部分(請參見此處的完整說明)。
[評論]如何獲取輸入和輸出節點名稱?
您可以使用op name
參數設置輸入和輸出節點名稱。
要在.pbtxt
文件中列出節點名稱,請使用以下腳本。
import tensorflow as tf
from google.protobuf import text_format
with open('graph.pbtxt') as f:
graph_def = text_format.Parse(f.read(), tf.GraphDef())
print [n.name for n in graph_def.node]
[評論]我發現有一個tf.estimator.Estimator.export_savedmodel(),是直接將模型存儲在.pb中的函數嗎? 我正在努力解決它的參數serving_input_receiver_fn。 有任何想法嗎?
export_savedmodel()
生成一個SavedModel
,這是TensorFlow模型的通用序列化格式。 它應該包含與TensorFlow Serving API配合所需的一切
serving_input_receiver_fn()
是生成SavedModel
必須提供的SavedModel
,它通過在圖上添加占位符來確定模型的輸入簽名。
從文檔
此功能具有以下目的:
- 要將占位符添加到服務系統將隨推理請求一起提供的圖上。
- 添加將數據從輸入格式轉換為模型所需的特征張量所需的任何其他操作。
如果您以序列化的tf.Examples
(這是一種典型模式)的形式接收推理請求,則可以使用doc中提供的示例。
feature_spec = {'foo': tf.FixedLenFeature(...),
'bar': tf.VarLenFeature(...)}
def serving_input_receiver_fn():
"""An input receiver that expects a serialized tf.Example."""
serialized_tf_example = tf.placeholder(dtype=tf.string,
shape=[default_batch_size],
name='input_example_tensor')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
[評論]有什么想法在'.pb'中列出節點名稱嗎?
這取決於它是如何生成的。
如果是SavedModel
則使用:
import tensorflow as tf
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(
sess,
[tf.saved_model.tag_constants.SERVING],
'./saved_models/1519232535')
print [n.name for n in meta_graph_def.graph_def.node]
如果它是一個MetaGraph
則使用:
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
with gfile.FastGFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print [n.name for n in graph_def.node]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.