tensorflow doesn't recognize the graph I import

I'm trying to reuse a graph saved by another .py file using tf.train.import_meta_graph().

test.py is the code in which I train and save my model:

import tensorflow as tf

W = tf.Variable(tf.random_normal([1]))
b = tf.Variable(tf.random_normal([1]))
X = tf.placeholder(dtype='float32', shape=None)
Y = tf.placeholder(dtype='float32', shape=[None])

Y_ = W * X + b
Y_ = tf.identity(Y_, name="Y_")
tf.add_to_collection("Y_", Y_)
tf.add_to_collection("X", X)
cost = tf.reduce_mean(tf.square(Y_ - Y))
train = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)

if __name__ == "__main__":
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        for i in range(10000):
            sess.run([train], feed_dict={X: [1, 2, 3], Y: [2, 4, 6]})
            print(sess.run(Y_, feed_dict={X: [1, 2, 3], Y: [2, 4, 6]}))
        saver.save(sess, "debug/foo")

test2.py is the code in which I load the saved model:

import tensorflow as tf
import test


with tf.Session() as sess:
    import_model = tf.train.import_meta_graph("debug/foo.meta")
    import_model.restore(sess,"debug/foo")
    print("restored")
    result = sess.run(['Y_:0'], feed_dict={'X:0': [1, 2, 3]})

However, in test2.py, when I import the graph and try to run it, I get the following error:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'X:0' refers 
to a Tensor which does not exist. The operation, 'X', does not exist in the 
graph.

What did I do wrong?

I'm using Python 3.5 on Windows 7 with TensorFlow 1.2.

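The tensor does not exist because your X placeholder has no name. TensorFlow gives an unnamed placeholder a default name such as Placeholder, so the saved graph contains no operation called X for 'X:0' to refer to. Give it a name explicitly: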

X = tf.placeholder(dtype=tf.float32, name='X')
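To confirm this diagnosis, you can import a meta graph into an empty graph and list the names of its placeholders. The sketch below rebuilds only the two unnamed placeholders from test.py and exports them to a temporary .meta file (the temporary path is an assumption for illustration). It uses the tf.compat.v1 module so that it runs on TensorFlow 1.14+ and 2.x; on TensorFlow 1.2 you would write import tensorflow as tf instead.

```python
import os
import tempfile

# Assumption: TensorFlow 1.14+ or 2.x; on 1.2 use `import tensorflow as tf`.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical temporary location for the exported meta graph
meta_path = os.path.join(tempfile.mkdtemp(), "foo.meta")

# Build the same two placeholders as test.py, without name=...
with tf.Graph().as_default():
    X = tf.placeholder(dtype="float32", shape=None)
    Y = tf.placeholder(dtype="float32", shape=[None])
    tf.train.export_meta_graph(filename=meta_path)

# Import into a fresh graph and collect the placeholder op names
with tf.Graph().as_default() as graph:
    tf.train.import_meta_graph(meta_path)
    placeholder_names = sorted(
        op.name for op in graph.get_operations() if op.type == "Placeholder"
    )

print(placeholder_names)  # ['Placeholder', 'Placeholder_1']
```

With name='X' on the first placeholder, its entry in that list would be X, and 'X:0' becomes a valid feed_dict key.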

The following code works, because both the placeholder (Y) and the output op (sum) are given explicit names:

import os

import tensorflow as tf

X = tf.Variable(tf.random_normal([1]))
Y = tf.placeholder(dtype=tf.float32, name='Y')
Z = tf.add(X, Y, name='sum')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(Z, {Y: 4}))

    # Make sure the checkpoint directory exists before saving
    os.makedirs('/tmp/model', exist_ok=True)
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, '/tmp/model/my_model')

# Start from an empty graph so only the imported ops exist
tf.reset_default_graph()

with tf.Session() as sess:
    loader = tf.train.import_meta_graph('/tmp/model/my_model.meta')
    # restore() assigns the saved values, so no initializer is needed
    loader.restore(sess, '/tmp/model/my_model')
    Z = tf.get_default_graph().get_tensor_by_name('sum:0')

    print(sess.run(Z, {'Y:0': 4}))
    print(sess.run('sum:0', {'Y:0': 4}))
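Alternatively, test.py already calls tf.add_to_collection("X", X) and tf.add_to_collection("Y_", Y_), and collections are stored in the .meta file, so the restoring script can fetch the tensors with tf.get_collection without renaming anything. A minimal sketch, assuming the tf.compat.v1 API of TensorFlow 1.14+/2.x (on 1.2 use import tensorflow as tf), fixed values for W and b instead of trained ones, and a temporary checkpoint directory:

```python
import os
import tempfile

# Assumption: TensorFlow 1.14+ or 2.x; on 1.2 use `import tensorflow as tf`.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical checkpoint location (the directory is created here)
ckpt_path = os.path.join(tempfile.mkdtemp(), "foo")

# Save: the placeholder is deliberately left unnamed, as in test.py
with tf.Graph().as_default():
    W = tf.Variable([2.0])  # fixed values stand in for trained weights
    b = tf.Variable([0.0])
    X = tf.placeholder(tf.float32)
    Y_ = tf.identity(W * X + b, name="Y_")
    tf.add_to_collection("X", X)
    tf.add_to_collection("Y_", Y_)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# Restore into a fresh graph and look the tensors up by collection
with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_path + ".meta")
        saver.restore(sess, ckpt_path)
        X = tf.get_collection("X")[0]
        Y_ = tf.get_collection("Y_")[0]
        result = sess.run(Y_, feed_dict={X: [1.0, 2.0, 3.0]})

print(result)  # [2. 4. 6.]
```

Looking tensors up by collection keeps the restoring script independent of auto-generated op names; giving the placeholder an explicit name, as above, works just as well.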
