I want to update a two-dimensional tf.Variable inside a tf.while_loop in TensorFlow, row by row. For this I use the tf.assign method. The problem is that with my implementation and parallel_iterations > 1 the result is wrong; with parallel_iterations = 1 the result is correct. The code looks like this:
a = tf.Variable(tf.zeros([100, 100], dtype=tf.int64))
i = tf.constant(0)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    updated_row = method()  # This method returns a [1, 100] tensor, the updated row for the variable
    temp = tf.assign(a[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)
The iterations are completely independent, and I do not know what the problem is. Strangely, if I change the code like this:
a = tf.Variable(tf.zeros([100, 100], dtype=tf.int64))
i = tf.constant(0)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    zeros = lambda: tf.zeros([100, 100], dtype=tf.int64)
    temp = tf.Variable(initial_value=zeros, dtype=tf.int64)
    updated_row = method()  # This method returns a [1, 100] tensor, the updated row for the variable
    temp = tf.assign(temp[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)
the code gives the correct outcome for parallel_iterations > 1. Can someone explain what is going on here and suggest an efficient way to update the variable? The original variable I want to update is huge, and the workaround I found is very inefficient.
You do not need to use variables for this; you can just generate the row-updated tensor in the loop body:
import tensorflow as tf

def method(i):
    # Placeholder logic
    return tf.cast(tf.range(i, i + 100), tf.float32)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    # Produce the new row
    updated_row = method(i)
    # Index vector that is 1 only at the row to update
    idx = tf.equal(tf.range(tf.shape(var)[0]), i)
    idx = tf.cast(idx[:, tf.newaxis], var.dtype)
    # Compose the new tensor from the old one and the new row
    var_updated = (1 - idx) * var + idx * updated_row
    return [tf.add(i, 1), var_updated]

# Start with zeros
a = tf.zeros([100, 100], tf.float32)
i = tf.constant(0)
i_end, a_updated = tf.while_loop(condition, body, [i, a], parallel_iterations=10)
with tf.Session() as sess:
    print(sess.run(a_updated))
Output:
[[ 0. 1. 2. ... 97. 98. 99.]
[ 1. 2. 3. ... 98. 99. 100.]
[ 2. 3. 4. ... 99. 100. 101.]
...
[ 97. 98. 99. ... 194. 195. 196.]
[ 98. 99. 100. ... 195. 196. 197.]
[ 99. 100. 101. ... 196. 197. 198.]]
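The masking trick at the heart of this answer can be checked in plain NumPy, independent of TensorFlow (an illustrative sketch with a small 4x4 matrix; update_row is a hypothetical helper name):

```python
import numpy as np

def update_row(var, i, new_row):
    # Column vector that is 1.0 at row i and 0.0 elsewhere
    idx = (np.arange(var.shape[0]) == i).astype(var.dtype)[:, np.newaxis]
    # Keep every row except i; replace row i with new_row (broadcasts)
    return (1 - idx) * var + idx * new_row

a = np.zeros((4, 4), dtype=np.float32)
a = update_row(a, 1, np.arange(1.0, 5.0, dtype=np.float32))
# Row 1 now holds [1, 2, 3, 4]; all other rows remain zero
```

Because the function returns a freshly composed tensor instead of mutating shared state, it maps directly onto the functional style tf.while_loop expects for its loop variables.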
In the tf.function documentation I found the following:

"Key Point: Any Python side-effects (appending to a list, printing with print, etc) will only happen once, when func is traced. To have side-effects executed into your tf.function they need to be written as TF ops."

I'm pretty sure that is what is going on here. You are expecting a to change, but that is a "side effect" ( https://runestone.academy/runestone/books/published/fopp/Functions/SideEffects.html ), which TensorFlow does not fully support. When you change a to temp, you are no longer relying on the side effect, and the code works.
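Since the masked update above rebuilds the entire [100, 100] tensor on every iteration, a cheaper pattern for large matrices is to collect each row in a tf.TensorArray and stack once at the end (a sketch assuming method(i) only needs the loop index, mirroring the placeholder from the first answer; adapt it to your real method()):

```python
import tensorflow as tf

def method(i):
    # Placeholder row generator, same assumption as the answer above
    return tf.cast(tf.range(i, i + 100), tf.float32)

def build_matrix(n):
    ta = tf.TensorArray(tf.float32, size=n)

    def condition(i, ta):
        return tf.less(i, n)

    def body(i, ta):
        # Each iteration writes its own slot; no read-modify-write
        # of shared state, so parallel iterations cannot race
        return tf.add(i, 1), ta.write(i, method(i))

    _, ta = tf.while_loop(condition, body,
                          [tf.constant(0), ta],
                          parallel_iterations=10)
    return ta.stack()  # shape [n, 100]

a_updated = build_matrix(100)
```

In eager TF 2.x this executes immediately; in TF 1.x graph mode, run the returned tensor in a session as in the example above.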