簡體   English   中英

如何在TensorFlow中使用Estimator將模型存儲在`.pb`文件中?

[英]How to store model in `.pb` file with Estimator in TensorFlow?

我使用TensorFlow的估算器訓練了模型。 似乎應該使用export_savedmodel來制作.pb文件,但是我真的不知道如何構造serving_input_receiver_fn 有人有想法嗎? 示例代碼受到歡迎。

額外問題:

  1. .pb是我要重新加載模型時唯一需要的文件嗎? Variable不必要嗎?

  2. 與使用adam優化程序的.ckpt相比, .pb將減少多少模型文件大小?

如果使用的是tf.estimator.Estimator ,則可以使用freeze_graph.py.ckpt + .pbtxt生成.pb ,然后在model_dir找到這兩個文件

python freeze_graph.py \
    --input_graph=graph.pbtxt \
    --input_checkpoint=model.ckpt-308 \
    --output_graph=output_graph.pb
    --output_node_names=<output_node>
  1. .pb是我要重新加載模型時唯一需要的文件嗎? 變量不必要嗎?

是的 ,您還必須知道您是模型的輸入節點和輸出節點名稱。 然后使用import_graph_def加載.pb文件,並使用get_operation_by_name獲取輸入和輸出操作

  1. 與使用adam優化程序的.ckpt相比,.pb將減少多少模型文件大小?

.pb文件不是壓縮的.ckpt文件,因此沒有“壓縮率”。

但是,有一種方法可以優化.pb文件以進行推理,並且此優化可能會減小文件大小,因為它會刪除僅訓練操作的圖形部分(請參見此處的完整說明)。

[評論]如何獲取輸入和輸出節點名稱?

您可以使用op name參數設置輸入和輸出節點名稱。

要在.pbtxt文件中列出節點名稱,請使用以下腳本。

import tensorflow as tf
from google.protobuf import text_format

with open('graph.pbtxt') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())

print [n.name for n in graph_def.node]

[評論]我發現有一個tf.estimator.Estimator.export_savedmodel(),是直接將模型存儲在.pb中的函數嗎? 我正在努力解決它的參數serving_input_receiver_fn。 有任何想法嗎?

export_savedmodel()生成一個SavedModel ,這是TensorFlow模型通用序列化格式。 它應該包含與TensorFlow Serving API配合所需的一切

serving_input_receiver_fn()是生成SavedModel必須提供的SavedModel ,它通過在圖上添加占位符來確定模型的輸入簽名。

從文檔

此功能具有以下目的:

  • 要將占位符添加到服務系統將隨推理請求一起提供的圖上。
  • 添加將數據從輸入格式轉換為模型所需的特征張量所需的任何其他操作。

如果您以序列化的tf.Examples (這是一種典型模式)的形式接收推理請求,則可以使用doc中提供的示例。

feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
  """An input receiver that expects a serialized tf.Example."""
  serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

[評論]有什么想法在'.pb'中列出節點名稱嗎?

這取決於它是如何生成的。

如果是SavedModel則使用:

import tensorflow as tf

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        './saved_models/1519232535')
    print [n.name for n in meta_graph_def.graph_def.node]

如果它是一個MetaGraph則使用:

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        print [n.name for n in graph_def.node]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM