
In TensorFlow, do you need to feed values that aren't relevant to what you need?

Am I correct that in TensorFlow, when I run anything, my feed_dict needs to give values to all my placeholders, even ones that are irrelevant to what I'm running?

In particular I'm thinking of making a prediction, in which case my targets placeholder is irrelevant.

Well, it depends on what your computation graph looks like and which ops you run that are fed by tensors (here: placeholders). If no part of the computation graph that you execute in the session depends on a placeholder, then that placeholder does not need to be fed a value. Here's a small example:

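# TensorFlow 1.x graph-mode API (in TF 2.x: import tensorflow.compat.v1 as tf
# and call tf.disable_eager_execution() first)
import tensorflow as tf

a = tf.constant([5, 5, 5], tf.float32, name='A')
b = tf.placeholder(tf.float32, shape=[3], name='B')  # defined but never used by d
c = tf.constant([3, 3, 3], tf.float32, name='C')
d = tf.add(a, c, name="Add")

with tf.Session() as sess:
    # d depends only on the constants A and C, so B does not need to be fed
    print(sess.run(d))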

# result       
[8. 8. 8.]

On the other hand, if you execute a part of the computation graph that depends on the placeholder, then the placeholder must be fed a value; otherwise TensorFlow raises an InvalidArgumentError. Here's an example demonstrating this:

import tensorflow as tf  # TensorFlow 1.x API

a = tf.constant([5, 5, 5], tf.float32, name='A')
b = tf.placeholder(tf.float32, shape=[3], name='B')
c = tf.add(a, b, name="Add")  # c now depends on the placeholder B

with tf.Session() as sess:
    print(sess.run(c))  # B is not fed, so this raises InvalidArgumentError

Executing the above code throws the following InvalidArgumentError:

InvalidArgumentError: You must feed a value for placeholder tensor 'B' with dtype float and shape [3]

[[Node: B = Placeholder[dtype=DT_FLOAT, shape=[3], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
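The rule at work in both examples, that only placeholders reachable from the fetched tensor need a value, can be illustrated with a toy backward walk over the graph. This is a minimal pure-Python sketch, not TensorFlow's actual pruning code; the `Node` class and `required_placeholders` function are hypothetical names introduced here, and the two graphs mirror the two examples above.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can go in a set
class Node:
    """Hypothetical, simplified graph node (not a TensorFlow class)."""
    name: str
    op: str  # "Const", "Placeholder" or "Add"
    inputs: list = field(default_factory=list)


def required_placeholders(fetch):
    """Return the names of the placeholders that `fetch` transitively depends on."""
    needed, seen, stack = set(), set(), [fetch]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node.op == "Placeholder":
            needed.add(node.name)
        # Walk backwards through inputs only; nodes the fetch does not
        # depend on (like an unused placeholder) are never visited.
        stack.extend(node.inputs)
    return needed


a = Node("A", "Const")
b = Node("B", "Placeholder")
c = Node("C", "Const")

# First example: B exists in the graph, but "Add" only reads A and C.
d = Node("Add", "Add", [a, c])
print(required_placeholders(d))    # set()

# Second example: "Add" reads the placeholder B, so B must be fed.
add = Node("Add", "Add", [a, b])
print(required_placeholders(add))  # {'B'}
```

An empty result corresponds to the first example running without a feed_dict; a non-empty result lists exactly the placeholders that must appear in feed_dict.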


So, to make it work, you have to feed the placeholder a value using feed_dict, as in:

import tensorflow as tf  # TensorFlow 1.x API

a = tf.constant([5, 5, 5], tf.float32, name='A')
b = tf.placeholder(tf.float32, shape=[3], name='B')
c = tf.add(a, b, name="Add")

with tf.Session() as sess:
    # Supply a value for the placeholder B that c depends on
    print(sess.run(c, feed_dict={b: [3, 3, 3]}))

# result
[8. 8. 8.]
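Applied to the question's scenario, a prediction op that depends only on the input placeholder can be run without feeding the targets placeholder, while a loss op that reads the targets cannot. The sketch below models this with a tiny lazy evaluator in plain Python so it runs without TensorFlow installed; `Placeholder`, `Const`, `Op` and `run` are hypothetical stand-ins for TensorFlow's graph and `Session.run`, not its real implementation, and the linear model is an assumed example.

```python
class Placeholder:
    """Hypothetical placeholder: its value must come from feed_dict."""
    def __init__(self, name):
        self.name = name


class Const:
    """Hypothetical constant node holding a fixed value."""
    def __init__(self, value):
        self.value = value


class Op:
    """Hypothetical op node: applies `fn` to the values of its inputs."""
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs


def run(fetch, feed_dict=None):
    """Evaluate `fetch`, visiting only the nodes it depends on."""
    feed_dict = feed_dict or {}
    if isinstance(fetch, Placeholder):
        if fetch not in feed_dict:
            raise ValueError(f"You must feed a value for placeholder '{fetch.name}'")
        return feed_dict[fetch]
    if isinstance(fetch, Const):
        return fetch.value
    return fetch.fn(*(run(inp, feed_dict) for inp in fetch.inputs))


# Toy linear model: predictions = w * x, loss = (predictions - targets) ** 2
x = Placeholder("inputs")
targets = Placeholder("targets")
w = Const(2.0)
predictions = Op(lambda w_val, x_val: w_val * x_val, w, x)
loss = Op(lambda p, t: (p - t) ** 2, predictions, targets)

print(run(predictions, feed_dict={x: 3.0}))         # 6.0 (targets not needed)
print(run(loss, feed_dict={x: 3.0, targets: 5.0}))  # 1.0
try:
    run(loss, feed_dict={x: 3.0})
except ValueError as err:
    print(err)  # You must feed a value for placeholder 'targets'
```

In TensorFlow 1.x the equivalent is calling `sess.run(predictions, feed_dict={x: ...})` with no entry for the targets placeholder, since the prediction op does not depend on it.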


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM