
How does tf.train.Saver work exactly?

I am a bit confused by how tf.train.Saver() works. I have the following code to save only the trainable variables:

import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver(tf.trainable_variables())
print([x.name for x in tf.trainable_variables()])
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "./model.ckpt")
  print("Model saved in file: %s" % save_path)

And the following code just to inspect them:

import tensorflow as tf
sess = tf.Session()
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(sess, 'model.ckpt')
print([v.name for v in tf.get_default_graph().as_graph_def().node])

The first code outputs ['v1:0', 'v2:0'], as expected. I expected the second code to produce the same result, but instead I see this:

['v1/Initializer/zeros', 'v1', 'v1/Assign', 'v1/read', 'v2/Initializer/zeros', 'v2', 'v2/Assign', 'v2/read', 'add/y', 'add', 'Assign', 'sub/y', 'sub', 'Assign_1', 'init', 'save/Const', 'save/SaveV2/tensor_names', 'save/SaveV2/shape_and_slices', 'save/SaveV2', 'save/control_dependency', 'save/RestoreV2/tensor_names', 'save/RestoreV2/shape_and_slices', 'save/RestoreV2', 'save/Assign', 'save/RestoreV2_1/tensor_names', 'save/RestoreV2_1/shape_and_slices', 'save/RestoreV2_1', 'save/Assign_1', 'save/restore_all']

I am not sure why TensorFlow seems to save all of the variables instead of just the two I specified. How can I save only those two?
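
Note that the second snippet prints every node in the imported graph definition (the initializers, the assign ops, and the save/restore ops the Saver itself added), not what was written to the checkpoint. To see which variables the checkpoint actually contains, a minimal sketch like the following should work, assuming TensorFlow 1.x and the ./model.ckpt prefix from the first snippet:

import tensorflow as tf

# Inspect the checkpoint file itself rather than the imported graph.
# Each entry is the name of a saved variable and its shape.
for name, shape in tf.train.list_variables('./model.ckpt'):
    print(name, shape)
# With the checkpoint saved above, this should list only v1 and v2.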

Try the following code from the TensorFlow wiki:

import tensorflow as tf

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Only the variables in var_list are saved/restored by this Saver.
saver = tf.train.Saver(var_list=[v1, v2])

with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "./model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())

I hope this helps!
