簡體   English   中英

Tensorflow MNIST示例:從SavedModel進行預測的代碼

[英]Tensorflow MNIST Sample: Code to Predict from SavedModel

我正在根據本文使用該示例構建CNN: https//www.tensorflow.org/tutorials/layers

但是,我無法通過輸入樣本圖像來找到要預測的樣本。 在這里的任何幫助將不勝感激。

下面是我嘗試過的,但找不到輸出張量名稱

img = <load from file>
sess = tf.Session()
saver = tf.train.import_meta_graph('/tmp/mnist_convnet_model/model.ckpt-2000.meta')
saver.restore(sess, tf.train.latest_checkpoint('/tmp/mnist_convnet_model/'))

input_place_holder = sess.graph.get_tensor_by_name("enqueue_input/Placeholder:0")
out_put = <not sure what the tensor output name in the graph>
current_input = img

result = sess.run(out_put, feed_dict={input_place_holder: current_input})
print(result)

您可以使用inspect_checkpoint在Tensorflow工具找到檢查點文件中的張量。

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name="tmp/mnist_convnet_model/model.ckpt-2000.meta", tensor_name='')

在Tensorflows編程指南中有關於如何保存和還原的很好的說明。 這是從后面的鏈接中得到啟發的一個小例子。 只要確保./tmp目錄存在

import tensorflow as tf
# Create some variables.
variable = tf.get_variable("variable_1", shape=[3], initializer=tf.zeros_initializer)
inc_v1=variable.assign(variable + 1)

# Operation to initialize variables if we do not restore from checkpoint
init_op = tf.global_variables_initializer()

# Create the saver
saver = tf.train.Saver()
with tf.Session() as sess:
    # Setting to decide wether or not to restore
    DO_RESTORE=True
    # Where to save the data file
    save_path="./tmp/model.ckpt"
    if DO_RESTORE:
        # If we want to restore, load the variables from the saved file
        saver.restore(sess, save_path)
    else:
        # If we don't want to restore, then initialize variables
        # using their specified initializers.
        sess.run(init_op)

    # Print the initial values of variable
    initial_var_value=sess.run(variable)
    print("Initial:", initial_var_value)
    # Do some work with the model.
    incremented=sess.run(inc_v1)
    print("Incremented:", incremented)
    # Save the variables to disk.
    save_path = saver.save(sess, save_path)
    print("Model saved in path: %s" % save_path)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM