
Assign op in TensorFlow: what is the return value?

I was trying to build an autoincrementing graph in TensorFlow. I thought that the assign op might be suitable for that, but found no documentation for it.

I assumed that this op returns its value, as assignment does in C-like languages, and wrote the following code:

import tensorflow as tf

counter = tf.Variable(0, name="counter")

one = tf.constant(1)
ten = tf.constant(10)

new_counter = tf.add(counter, one)
assign = tf.assign(counter, new_counter)
result = tf.add(assign, ten)

init_op = tf.initialize_all_variables()

with tf.Session() as sess:
  sess.run(init_op)
  for _ in range(3):
    print(sess.run(result))

and this code works.

The question is: is this the expected behavior? Why is the assign op not documented here: https://www.tensorflow.org/versions/0.6.0/api_docs/index.html

Is it a non-recommended op?

The tf.assign() operator is the underlying mechanism that implements the Variable.assign() method. It takes a mutable tensor (with tf.*_ref type) and a new value, and returns a mutable tensor that has been updated with the new value. The return value is provided to make it easier to order an assignment before a subsequent read, but this feature is not well documented. An example will hopefully illustrate:

v = tf.Variable(0)
new_v = v.assign(10)
output = v + 5  # `v` is evaluated before or after the assignment.

sess.run(v.initializer)

result, _ = sess.run([output, new_v.op])
print(result)  # ==> 10 or 15, depending on the order of execution.

v = tf.Variable(0)
new_v = v.assign(10)
output = new_v + 5  # `new_v` is evaluated after the assignment.

sess.run(v.initializer)

result = sess.run(output)
print(result)  # ==> 15

In your code example, the dataflow dependencies enforce the execution order [read counter] -> new_counter = tf.add(...) -> tf.assign(...) -> [read output of assign] -> result = tf.add(...), so the semantics are unambiguous. However, the read-modify-write steps used to update the counter are somewhat inefficient, and they can behave unexpectedly when multiple steps run concurrently. For example, multiple threads accessing the same variable could observe the counter moving backwards (if an older value is written back after a newer one).
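The lost-update hazard behind this warning can be sketched without TensorFlow at all. The following plain-Python simulation (the variable names are illustrative, not from the answer above) interleaves two read-modify-write increments so that both read the counter before either writes back:

```python
# Two concurrent steps each try to increment a shared counter.
counter = 0

# Both steps read the current value before either writes back.
read_a = counter  # step A reads 0
read_b = counter  # step B reads 0

# Each writes back its own increment; B's write clobbers A's.
counter = read_a + 1  # counter == 1
counter = read_b + 1  # counter == 1 again: A's increment is lost

print(counter)  # 1, even though two increments ran
```

This is exactly the interleaving that `assign_add(..., use_locking=True)` prevents: each update's read and write are serialized, so no increment is lost.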

I would recommend that you use Variable.assign_add() to update the counter, as follows:

counter = tf.Variable(0, name="counter")

one = tf.constant(1)
ten = tf.constant(10)

# assign_add ensures that the counter always moves forward.
updated_counter = counter.assign_add(one, use_locking=True)

result = tf.add(updated_counter, ten)
# ...
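A complete, runnable version of this counter might look as follows. It uses the `tf.compat.v1` API so it also works under TensorFlow 2.x; on TensorFlow 1.x you can drop the `compat.v1` prefixes:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # run in graph-and-session mode

counter = tf.compat.v1.Variable(0, name="counter")
one = tf.constant(1)
ten = tf.constant(10)

# assign_add serializes each read-modify-write, so the counter only moves forward.
updated_counter = counter.assign_add(one, use_locking=True)
result = updated_counter + ten

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    values = [sess.run(result) for _ in range(3)]

print(values)  # [11, 12, 13]
```

Because `result` depends on `updated_counter`, each `sess.run(result)` performs the increment first and then adds ten, so the printed values advance by one per step.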

tf.assign() is well documented in more recent versions, and it is used frequently in real projects.

This operation outputs "ref" after the assignment is done. This makes it easier to chain operations that need to use the reset value.

In simpler terms, it takes your original tensor and a new value: it updates the original tensor with that new value and returns a reference to the original tensor.
