繁体   English   中英

如何正确创建saved_model.pb?

[英]How to correctly create saved_model.pb?

我一直在尝试创建一个进行推理所需的saved_model.pb文件(来自.ckpt.meta文件)。 我可以成功创建一个包含saved_model.pb和变量的文件,但是当我部署我的脚本时,我在预期的张量上得到一个KeyError

y_probs = [my_predictor._fetch_tensors['y_prob_{}'.format(p)] for p in protocols]

KeyError: 'y_prob_protocol1'

问题可能在于我如何定义我的输入/输出(见最后的代码),因为 feed 和 fetch 张量是空的,如下所示:

my_predictor = predictor.from_saved_model('export')

SavedModelPredictor with feed tensors {} and fetch_tensors {}
saver = tf.train.import_meta_graph(opts.model)
builder = tf.saved_model.builder.SavedModelBuilder(opts.out_path)

with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, opts.checkpoint)
    print("Model restored.")

    input_tensor = tf.placeholder(tf.float32, shape=(None,128,128,128,1), name='tensors/component_0')
    tensor_1 = tf.placeholder(tf.float32, shape=(None,128,128,128,2), name='pred/Reshape')
    tensor_2 = tf.placeholder(tf.float32, shape=(None,128,128,128,3), name='pred1/Reshape')


    tensor_info_input = tf.saved_model.utils.build_tensor_info(input_tensor)
    tensor_info_1 = tf.saved_model.utils.build_tensor_info(tensor_1)
    tensor_info_2 = tf.saved_model.utils.build_tensor_info(tensor_2)


    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'x': tensor_info_input},
            outputs={'y_prob_protocol1': tensor_info_1, 'y_prob_protocol2':tensor_info_2},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'predict_images':
                prediction_signature,
        })

    builder.save()   

感谢您的帮助 !

我怀疑这个错误可能有两个原因:

  1. 恢复的模型(使用检查点保存)可能没有正确链接到builder.save()saved_model.pb文件。

  2. 您在 SignatureDef 中使用了 2 个输出, tensor_info_1tensor_info_2 但是它们没有定义(至少在显示的代码中)。 根据定义,我的意思是,比如,

    y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')

您可以使用这个简单的脚本将检查点和元文件转换为 .pb 文件。 但是您必须指定输出节点的名称。

import tensorflow as tf

meta_path = 'model.ckpt-22480.meta' # Your .meta file
output_node_names = ['output:0']    # Output nodes

with tf.Session() as sess:

    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('.'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

这种转换工作量太大。 相反,保存模型的Check Points ,然后试图将其转换为.pb file ,你可以保存模型,图形和SignatureDefs直接.pb file或者使用SavedModelBuilder或使用export_saved_model

以下链接提供了使用SavedModelBuilder保存模型的示例代码。

这是 Google Tensorflow Serving Team 提供的官方代码,建议遵循此代码(流程和结构)。

https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM