
TensorFlow: make an assign op an explicit dependency for computing a tensor

I want an assign op to run implicitly every time I evaluate another tensor that depends on the tf.Variable modified by that assign op. I don't want to run the assign op manually at every step. I tried two different approaches. Here is a simplified example:

import tensorflow as tf
import numpy as np

target_prob     = tf.placeholder(dtype=tf.float32, shape=[None, 2])
target_var      = tf.Variable(0, trainable=False, dtype=tf.float32)
init_target_var = tf.assign(target_var, tf.zeros_like(target_prob),
                            validate_shape=False)

# First approach
with tf.control_dependencies([init_target_var]):
  result = target_prob + target_var

# Second approach
# [target_var] = tf.tuple([target_var], control_inputs=[init_target_var])
# result = target_prob + target_var

sess = tf.Session()
sess.run(tf.global_variables_initializer())
res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})

Both approaches fail when res2 is computed, with the error InvalidArgumentError (see above for traceback): Incompatible shapes: [12,2] vs. [10,2]. Everything works if I instead run init_target_var explicitly before the second call:

res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
sess.run(init_target_var, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})

But again, running init_target_var explicitly is exactly what I am trying to avoid.

P.S. The above is just a simplified example. My final goal is to use the tensor returned by tf.scatter_add, which unfortunately requires a mutable tensor (one backed by a variable) as input.
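One way to reach that goal is sketched below under stated assumptions; it is not code from the thread. With reference (non-resource) variables in graph mode, tf.assign returns a ref-typed tensor that holds the variable's new value, so that output can be passed directly to tf.scatter_add. Each run then resets the accumulator to zeros shaped like target_prob and scatters into it, with no separate sess.run call. The sketch assumes a TensorFlow release that ships tf.compat.v1 (for example 1.15 or 2.x); acc_var, indices and updates are hypothetical names added for illustration.

```python
# Sketch: assumes tf.compat.v1 graph mode with reference variables.
# acc_var, indices and updates are hypothetical, for illustration only.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference (non-resource) variables

target_prob = tf.placeholder(dtype=tf.float32, shape=[None, 2])
indices = tf.placeholder(dtype=tf.int32, shape=[None])       # hypothetical
updates = tf.placeholder(dtype=tf.float32, shape=[None, 2])  # hypothetical

acc_var = tf.Variable(0, trainable=False, dtype=tf.float32,
                      validate_shape=False)
# tf.assign returns a ref tensor holding the freshly zeroed variable,
# so it is accepted where tf.scatter_add expects a mutable input.
zeroed = tf.assign(acc_var, tf.zeros_like(target_prob), validate_shape=False)
# summed consumes zeroed as a data input, so the reset always runs first.
summed = tf.scatter_add(zeroed, indices, updates)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Duplicate indices accumulate: row 0 receives two updates.
out1 = sess.run(summed, feed_dict={
    target_prob: np.ones([4, 2], dtype=np.float32),
    indices: [0, 0, 3],
    updates: np.ones([3, 2], dtype=np.float32)})

# A different batch size in the next run; the previous sums do not leak in.
out2 = sess.run(summed, feed_dict={
    target_prob: np.ones([6, 2], dtype=np.float32),
    indices: [5],
    updates: np.array([[3.0, 4.0]], dtype=np.float32)})

print(out1)
print(out2)
```

Because the ordering comes from a data dependency on zeroed, no control_dependencies block is needed here.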

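For anyone who comes across this: I was using the wrong tensor when computing result. It has to be built from the tensor returned by tf.assign, which holds the variable's value after the assignment, rather than from the variable itself. Reading the variable directly goes through a read op that was created together with the variable, outside the control_dependencies block, so nothing forces that read to happen after the assign. The correct code is: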

import tensorflow as tf
import numpy as np

target_prob         = tf.placeholder(dtype=tf.float32, shape=[None, 2])
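tmp_var             = tf.Variable(0, trainable=False, dtype=tf.float32, validate_shape=False)
# Use the output of tf.assign: it holds the variable's value after the reset.
target_var          = tf.assign(tmp_var, tf.zeros_like(target_prob), validate_shape=False)

# result already depends on target_var through its data input, so this
# control_dependencies block is redundant (but harmless).
with tf.control_dependencies([target_var]):
  result = target_prob + target_var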

sess = tf.Session()
sess.run(tf.global_variables_initializer())

res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})
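Why the first approach in the question failed, and an alternative fix that keeps reading the variable, can be sketched as follows. With reference variables, target_prob + target_var reads the variable through the cached op returned by target_var.value(), which is created together with the variable and therefore outside the control_dependencies block; tf.control_dependencies only affects ops created inside it. Calling target_var.read_value() inside the block creates a new read op that does get the assign as a control input. This sketch assumes tf.compat.v1 graph mode with reference variables (for example TensorFlow 1.15, or 2.x after tf.disable_v2_behavior()); the names fresh_read and cached_read are illustrative.

```python
# Sketch: assumes tf.compat.v1 graph mode with reference variables.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference (non-resource) variables

target_prob = tf.placeholder(dtype=tf.float32, shape=[None, 2])
target_var = tf.Variable(0, trainable=False, dtype=tf.float32,
                         validate_shape=False)
init_target_var = tf.assign(target_var, tf.zeros_like(target_prob),
                            validate_shape=False)

with tf.control_dependencies([init_target_var]):
    # read_value() builds a fresh read op here, inside the block, so that
    # op receives init_target_var as a control input.
    fresh_read = target_var.read_value()
    result = target_prob + fresh_read

# value() returns the cached read op built together with the variable,
# outside the block; `target_prob + target_var` would use this one.
cached_read = target_var.value()
print("cached read control inputs:",
      [op.name for op in cached_read.op.control_inputs])
print("fresh read control inputs:",
      [op.name for op in fresh_read.op.control_inputs])

sess = tf.Session()
sess.run(tf.global_variables_initializer())
res1 = sess.run(result, feed_dict={target_prob: np.ones([10, 2], dtype=np.float32)})
res2 = sess.run(result, feed_dict={target_prob: np.ones([12, 2], dtype=np.float32)})
print(res1.shape, res2.shape)
```

Using the output of tf.assign, as in the answer, is simpler because the ordering then follows from a data dependency; read_value() is useful when the code has to refer to the variable object itself.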

The content above is licensed under CC BY-SA 4.0.

 