
How to restore a saved tensorflow model?

There are two Python files: the first saves a TensorFlow model, and the second restores the saved model.

Question:

  1. When I run the two files one after another, it works.

  2. When I run the first one, restart the editor, and then run the second one, it tells me that w1 is not defined.

What I want to do is:

  1. Save a tensorflow model

  2. Restore the saved model

What's wrong with it? Thanks for your kind help.

model_save.py

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.save(sess, 'SR\\my-model')

model_restore.py

import tensorflow as tf

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('SR\\my-model.meta')
  saver.restore(sess,'SR\\my-model')
  print (sess.run(w1))


Briefly, you should use

print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))

instead of print (sess.run(w1)) in your model_restore.py file.

model_save.py

import tensorflow as tf
w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(w1_node.eval()) # [ 0.43350926  1.02784836]
  #print(w1.eval()) # NameError: name 'w1' is not defined
  saver.save(sess, 'my-model')

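w1_node is a Python variable that exists only in model_save.py, so model_restore.py cannot see it. The same is true of w1 in your original script: once the interpreter restarts, that Python name is gone, even though the graph node named 'w1' was saved. Running the two files one after another presumably works only because w1 is still defined in the same interpreter session. To fetch a tensor by its graph name, use get_tensor_by_name, as the post "Tensorflow: How to get a tensor by name?" suggests. The name 'w1:0' refers to output 0 of the operation named 'w1'.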
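
To make the difference between the Python name and the graph name concrete, here is a minimal sketch. It assumes a TensorFlow install that ships the `tensorflow.compat.v1` module (1.14 or later, including 2.x) and uses `tf.disable_v2_behavior()` so that variables behave as in the listings above. The names `graph`, `looked_up`, `op_name` and `tensor_name` are illustrative only.

```python
import tensorflow.compat.v1 as tf  # TF1-style graph API

tf.disable_v2_behavior()  # assumption: needed only when running under TF 2.x

graph = tf.Graph()
with graph.as_default():
    # Python name: w1_node. Graph name: 'w1' (set by the name= argument).
    w1_node = tf.Variable(tf.zeros(shape=[2]), name='w1')

# Look the tensor up by its graph name instead of by the Python variable.
# 'w1:0' means "output number 0 of the operation called 'w1'".
looked_up = graph.get_tensor_by_name('w1:0')

op_name = w1_node.op.name      # name of the operation in the graph
tensor_name = looked_up.name   # name of the tensor that the lookup returned
print(op_name, tensor_name)
```

Only the graph name is stored in the .meta file, which is why the restore script has to go through it.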

model_restore.py

import tensorflow as tf

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('my-model.meta')
  saver.restore(sess,'my-model')
  print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0')))
  # [ 0.43350926  1.02784836]
  print(tf.global_variables()) # print tensor variables
  # [<tf.Variable 'w1:0' shape=(2,) dtype=float32_ref>,
  #  <tf.Variable 'w2:0' shape=(5,) dtype=float32_ref>]
  for op in tf.get_default_graph().get_operations():
    print(op.name) # print the name of every operation node
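
If you want to check the round trip without restarting the editor, one option is to run both halves in a single script but in separate tf.Graph objects, so the restore half cannot accidentally reuse anything from the save half. This is only a sketch under the same `tensorflow.compat.v1` assumption as above. The temporary directory and the names `ckpt_prefix`, `save_graph`, `restore_graph`, `saved_value` and `restored_value` are made up for the example.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style graph API

tf.disable_v2_behavior()  # assumption: needed only when running under TF 2.x

# Hypothetical checkpoint location inside a fresh temporary directory.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'my-model')

# Save side: build the graph, initialize the variable, write the checkpoint.
save_graph = tf.Graph()
with save_graph.as_default():
    w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    saver = tf.train.Saver()
    with tf.Session(graph=save_graph) as sess:
        sess.run(tf.global_variables_initializer())
        saved_value = sess.run(w1_node)
        saver.save(sess, ckpt_prefix)  # also writes ckpt_prefix + '.meta'

# Restore side: a brand-new graph that knows nothing about w1_node.
restore_graph = tf.Graph()
with restore_graph.as_default():
    with tf.Session(graph=restore_graph) as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
        saver.restore(sess, ckpt_prefix)
        # Fetch by graph name, exactly as model_restore.py does.
        restored_value = sess.run(restore_graph.get_tensor_by_name('w1:0'))

print(saved_value, restored_value)
```

The two printed arrays should match, since the restore side reads back the values that the save side wrote.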
