[英]How to correctly create saved_model.pb?
I have been trying to create a saved_model.pb
file (from .ckpt
, .meta
files) that is needed in order to do inference.我一直在尝试创建一个进行推理所需的saved_model.pb
文件(来自.ckpt
、 .meta
文件)。 I can successfully create a file which contains saved_model.pb
and variables, however when I deploy my script, I get a KeyError
on the expected tensors:我可以成功创建一个包含saved_model.pb
和变量的文件,但是当我部署我的脚本时,我在预期的张量上得到一个KeyError
:
y_probs = [my_predictor._fetch_tensors['y_prob_{}'.format(p)] for p in protocols]
KeyError: 'y_prob_protocol1'
The problem is probably in how I've defined my inputs/outputs (see code at the end) because the feed and fetch tensors are empty as you can see below:问题可能在于我如何定义我的输入/输出(见最后的代码),因为 feed 和 fetch 张量是空的,如下所示:
my_predictor = predictor.from_saved_model('export')
SavedModelPredictor with feed tensors {} and fetch_tensors {}
saver = tf.train.import_meta_graph(opts.model)
builder = tf.saved_model.builder.SavedModelBuilder(opts.out_path)
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, opts.checkpoint)
print("Model restored.")
input_tensor = tf.placeholder(tf.float32, shape=(None,128,128,128,1), name='tensors/component_0')
tensor_1 = tf.placeholder(tf.float32, shape=(None,128,128,128,2), name='pred/Reshape')
tensor_2 = tf.placeholder(tf.float32, shape=(None,128,128,128,3), name='pred1/Reshape')
tensor_info_input = tf.saved_model.utils.build_tensor_info(input_tensor)
tensor_info_1 = tf.saved_model.utils.build_tensor_info(tensor_1)
tensor_info_2 = tf.saved_model.utils.build_tensor_info(tensor_2)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_input},
outputs={'y_prob_protocol1': tensor_info_1, 'y_prob_protocol2':tensor_info_2},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'predict_images':
prediction_signature,
})
builder.save()
Thank you for your help !感谢您的帮助 !
I suspect there might be 2 reasons for this error:我怀疑这个错误可能有两个原因:
The Restored Model (Saved using Check Points) might not be properly linked to the builder.save()
's, saved_model.pb
file.恢复的模型(使用检查点保存)可能没有正确链接到builder.save()
的saved_model.pb
文件。
You have used 2 outputs, tensor_info_1
and tensor_info_2
in the SignatureDef.您在 SignatureDef 中使用了 2 个输出, tensor_info_1
和tensor_info_2
。 But they are not defined (at least in the code shown).但是它们没有定义(至少在显示的代码中)。 By definition, I mean something, like,根据定义,我的意思是,比如,
y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')
. y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')
。
You can use this simple script to convert from Checkpoints and Meta Files to .pb file.您可以使用这个简单的脚本将检查点和元文件转换为 .pb 文件。 But you must specify the names of the output nodes.但是您必须指定输出节点的名称。
import tensorflow as tf
meta_path = 'model.ckpt-22480.meta' # Your .meta file
output_node_names = ['output:0'] # Output nodes
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
saver.restore(sess,tf.train.latest_checkpoint('.'))
# Freeze the graph
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names)
# Save the frozen graph
with open('output_graph.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
This conversion is too much of work.这种转换工作量太大。 Instead of Saving the Model to Check Points
and then trying to convert it to .pb file
, you can Save the Model, Graphs and the SignatureDefs
directly to .pb file
either using SavedModelBuilder
or using export_saved_model
.相反,保存模型的Check Points
,然后试图将其转换为.pb file
,你可以保存模型,图形和SignatureDefs
直接.pb file
或者使用SavedModelBuilder
或使用export_saved_model
。
Example code for Saving a Model using SavedModelBuilder
is given in the below link.以下链接提供了使用SavedModelBuilder
保存模型的示例代码。
It is the official code provided by Google Tensorflow Serving Team and following this code (flow and structure) would be recommended.这是 Google Tensorflow Serving Team 提供的官方代码,建议遵循此代码(流程和结构)。
https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.