I'm trying to reuse the graph from another .py file using tf.train.import_meta_graph ()
test.py is the code which I train/save my model. code below is test.py
import tensorflow as tf
W = tf.Variable(tf.random_normal([1]))
b= tf.Variable(tf.random_normal([1]))
X= tf.placeholder(dtype='float32',shape=None)
Y= tf.placeholder(dtype='float32',shape=[None])
Y_ = W*X +b
Y_ =tf.identity(Y_,name="Y_")
tf.add_to_collection("Y_",Y_)
tf.add_to_collection("X",X)
cost = tf.reduce_mean(tf.square(Y_-Y))
train = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
if __name__ == "__main__":
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for i in range(10000):
sess.run([train],feed_dict={X:[1,2,3],Y:[2,4,6]})
print((sess.run(Y_,feed_dict={X:[1,2,3],Y:[2,4,6]})))
saver.save(sess,"debug/foo")
test2.py is the code which I load my previous model . The codes below is test2.py
import tensorflow as tf
import test
with tf.Session() as sess:
import_model = tf.train.import_meta_graph("debug/foo.meta")
import_model.restore(sess,"debug/foo")
print("restored")
result= sess.run(['Y_:0'],feed_dict={'X:0':[1,2,3]})
However,in test2.py ,when I import the graph and try to run it. It gives me the following error
TypeError: Cannot interpret feed_dict key as Tensor: The name 'X:0' refers
to a Tensor which does not exist. The operation, 'X', does not exist in the
graph.
What did I do wrong?
I'm using python 3.5 and window 7 and my tensorflow version is 1.2
The tensor does not exists, because your X
has no name. You should write
X = tf.placeholder(dtype=tf.float32, name='X')
The following code works:
import tensorflow as tf
X = tf.Variable(tf.random_normal([1]))
Y = tf.placeholder(dtype=tf.float32, name='Y')
Z = tf.add(X, Y, name='sum')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(Z, {Y: 4})
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, '/tmp/model/my_model')
tf.reset_default_graph()
with tf.Session() as sess:
loader = tf.train.import_meta_graph('/tmp/model/my_model.meta')
sess.run(tf.global_variables_initializer())
loader = loader.restore(sess, '/tmp/model/my_model')
Z = tf.get_default_graph().get_tensor_by_name('sum:0')
print sess.run(Z, {'Y:0': 4})
print sess.run('sum:0', {'Y:0': 4})
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.