How to use a model after training in TensorFlow (save/load graph)

My TensorFlow version is 0.11. I want to save a graph after training, or save something else that TensorFlow can load.

I/ Using exporting and importing a MetaGraph

I have already read this post: "Tensorflow: how to save/restore a model?"

My Save.py file:

import tensorflow as tf

X = tf.placeholder("float", [None, 28, 28, 1], name='X')
Y = tf.placeholder("float", [None, 10], name='Y')
# ... model definition omitted (defines py_x, p_keep_conv, p_keep_hidden) ...

saver = tf.train.Saver()
with tf.Session() as sess:
    # ... run something (training) ...
    final_tensor = tf.nn.softmax(py_x, name='final_result')
    tf.add_to_collection("final_tensor", final_tensor)

    predict_op = tf.argmax(py_x, 1)
    tf.add_to_collection("predict_op", predict_op)

    # Save while the session is still open.
    saver.save(sess, 'my_project')
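A minimal, self-contained sketch of this save step, assuming the TF 1.x graph API (written against tensorflow.compat.v1 so it also runs on TF 2.x; older releases such as 0.11 use import tensorflow as tf and spell the initializer tf.initialize_all_variables()). A single hypothetical weight matrix w stands in for the elided convolutional network, and the checkpoint goes to a temporary directory instead of the working directory. The Saver is created only after the variables exist, and saver.save is called while the session is open; it also writes my_project.meta, including the collections registered so far.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF 1.x graph API (on TF 0.11: import tensorflow as tf)

# Hypothetical checkpoint prefix inside a temporary directory.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "my_project")

with tf.Graph().as_default():
    X = tf.placeholder(tf.float32, [None, 28, 28, 1], name="X")
    # Hypothetical stand-in for the real network: a single linear layer.
    w = tf.Variable(tf.zeros([784, 10]), name="w")
    py_x = tf.matmul(tf.reshape(X, [-1, 784]), w)

    final_tensor = tf.nn.softmax(py_x, name="final_result")
    tf.add_to_collection("final_tensor", final_tensor)
    predict_op = tf.argmax(py_x, 1)
    tf.add_to_collection("predict_op", predict_op)

    # Create the Saver only after the variables it should save exist.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... training steps would run here ...
        # Saving inside the session also exports my_project.meta.
        saver.save(sess, ckpt_prefix)

# Re-import the MetaGraph into a fresh graph and check that the collections survived.
with tf.Graph().as_default():
    tf.train.import_meta_graph(ckpt_prefix + ".meta")
    predict_ops = tf.get_collection("predict_op")
    final_names = [t.name for t in tf.get_collection("final_tensor")]

print(len(predict_ops), final_names)
```

Because the collections travel inside the .meta file, load.py can find predict_op with tf.get_collection without rebuilding the model code.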

Then I run load.py:

import numpy as np
import tensorflow as tf

# teX, teY and test_size come from the MNIST loading code (not shown).
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_project.meta')
    new_saver.restore(sess, 'my_project')
    predict_op = tf.get_collection("predict_op")[0]
    for i in range(2):
        test_indices = np.arange(len(teX))  # Get a test batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))

but it returns an error:

Traceback (most recent call last):
  File "load_05_convolution.py", line 62, in <module>
    "p_keep_hidden:0": 1.0})))
  File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (256, 784) for Tensor u'X:0', which has shape '(?, 28, 28, 1)'

I really don't know why.

If I add final_tensor = tf.get_collection("final_result")[0],

it returns another error:

Traceback (most recent call last):
  File "load_05_convolution.py", line 46, in <module>
    final_tensor = tf.get_collection("final_result")[0]
IndexError: list index out of range

Is it because tf.add_to_collection can only hold one placeholder?
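A collection can hold any number of items; what tf.get_collection looks up is the key that was passed to tf.add_to_collection. In Save.py the softmax tensor was stored under the key "final_tensor", while "final_result" is only its op name, so tf.get_collection("final_result") returns an empty list and [0] raises IndexError. A minimal sketch of the difference, assuming the TF 1.x graph API via tensorflow.compat.v1 and a hypothetical placeholder named logits in place of py_x:

```python
import tensorflow.compat.v1 as tf  # TF 1.x graph API (on TF 0.11: import tensorflow as tf)

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for py_x.
    logits = tf.placeholder(tf.float32, [None, 10], name="logits")
    final_tensor = tf.nn.softmax(logits, name="final_result")  # op name
    tf.add_to_collection("final_tensor", final_tensor)         # collection key

    by_key = tf.get_collection("final_tensor")      # lookup by key: one tensor
    by_op_name = tf.get_collection("final_result")  # no such key: empty list
    by_tensor_name = graph.get_tensor_by_name("final_result:0")  # lookup by name

print(len(by_key), by_op_name, by_tensor_name.name)
```

So after importing the MetaGraph, either tf.get_collection("final_tensor")[0] or sess.graph.get_tensor_by_name("final_result:0") should return the softmax tensor.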

II/ Using tf.train.write_graph

I added this line to the end of save.py: tf.train.write_graph(graph, 'folder', 'train.pb')

It created the file 'train.pb' successfully.

My load.py:

import numpy as np
import tensorflow as tf

with tf.gfile.FastGFile('folder/train.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    predict_op = sess.graph.get_tensor_by_name('predict_op:0')
    for i in range(2):
        test_indices = np.arange(len(teX))  # Get a test batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]

        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))

Then it returns an error:

Traceback (most recent call last):
  File "load_05_convolution.py", line 22, in <module>
    graph_def.ParseFromString(f.read())
  File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString
    raise message_mod.DecodeError('Unexpected end-group tag.')
google.protobuf.message.DecodeError: Unexpected end-group tag.
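This DecodeError is consistent with a format mismatch: tf.train.write_graph writes a human-readable text protobuf by default (as_text=True), while GraphDef.ParseFromString expects the binary wire format. Two further details matter here: predict_op was created without a name= argument, so its tensor is not called 'predict_op:0', and a GraphDef stores only the graph structure, not trained variable values, which still have to come from a checkpoint or be frozen into constants. A minimal sketch of a binary round trip, assuming the TF 1.x graph API via tensorflow.compat.v1, a temporary directory in place of 'folder', and a toy graph in place of the real model:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph API (on TF 0.11: import tensorflow as tf)

out_dir = tempfile.mkdtemp()  # stand-in for 'folder'

graph = tf.Graph()
with graph.as_default():
    X = tf.placeholder(tf.float32, [None, 10], name="X")
    # Explicit name so the tensor can be looked up later as "predict_op:0".
    tf.argmax(X, 1, name="predict_op")

# as_text=False writes the binary format that ParseFromString understands.
tf.train.write_graph(graph.as_graph_def(), out_dir, "train.pb", as_text=False)

with tf.gfile.GFile(os.path.join(out_dir, "train.pb"), "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

loaded = tf.Graph()
with loaded.as_default():
    tf.import_graph_def(graph_def, name="")
    predict_op = loaded.get_tensor_by_name("predict_op:0")
    with tf.Session() as sess:
        batch = np.array([[0, 3, 1, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32)
        result = sess.run(predict_op, feed_dict={"X:0": batch})

print(result)  # [1]
```

If a graph was already written in text form, it can instead be read in text mode and parsed with google.protobuf.text_format.Merge(f.read(), graph_def).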

Would you mind sharing the standard way (code or a tutorial) to save/load a model? I'm really confused.

Your first solution (using the MetaGraph) almost works, but the error arises because you are feeding a batch of flattened MNIST examples (shape (256, 784) in the traceback) to a tf.placeholder() that expects a batch of MNIST examples as a 4-D tensor with shape batch_size x height (= 28) x width (= 28) x channels (= 1). The easiest way to solve this is to reshape your input data. Instead of this statement:

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                 sess.run(predict_op, feed_dict={
                     "X:0": teX[test_indices],
                     "p_keep_conv:0": 1.0,
                     "p_keep_hidden:0": 1.0})))

...try the following statement instead, which reshapes your input data appropriately:

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                 sess.run(predict_op, feed_dict={
                     "X:0": teX[test_indices].reshape(-1, 28, 28, 1),
                     "p_keep_conv:0": 1.0,
                     "p_keep_hidden:0": 1.0})))
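To see what the reshape does, here is a NumPy-only sketch with a dummy batch shaped like the one in the traceback (256 flattened 28 x 28 images); the values are arbitrary placeholders, not MNIST data. Passing -1 lets NumPy infer the batch dimension, and each 784-value row becomes one 28 x 28 x 1 image, filled row by row.

```python
import numpy as np

# Dummy stand-in for teX[test_indices]: 256 flattened 28x28 images.
flat_batch = np.arange(256 * 784, dtype=np.float32).reshape(256, 784)

# -1 infers the batch size: 256 * 784 / (28 * 28 * 1) = 256.
images = flat_batch.reshape(-1, 28, 28, 1)
print(images.shape)  # (256, 28, 28, 1)

# Row i of the flat batch is exactly image i, laid out row by row.
print(np.array_equal(images[5, :, :, 0], flat_batch[5].reshape(28, 28)))  # True
```

This relies on teX storing each image in row-major order; data flattened in a different order would need a matching reshape or transpose.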
