
tf.while_loop gives wrong result when it runs in parallel

I want to update a two-dimensional tf.Variable row by row inside a tf.while_loop in TensorFlow, so I use the tf.assign method. The problem is that with my implementation and parallel_iterations > 1 the result is wrong, while with parallel_iterations = 1 it is correct. The code looks like this:

import tensorflow as tf

a = tf.Variable(tf.zeros([100, 100], dtype=tf.int64))

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    updated_row = method() # This method returns a [100] tensor which is the updated row for the variable
    temp = tf.assign(a[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

The iterations are completely independent, so I do not know what the problem is.

Strangely, if I change the code like this:

a = tf.Variable(tf.zeros([100, 100], dtype=tf.int64))

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    zeros = lambda: tf.zeros([100, 100], dtype=tf.int64)
    temp = tf.Variable(initial_value=zeros, dtype=tf.int64)
    updated_row = method() # This method returns a [100] tensor which is the updated row for the variable
    temp = tf.assign(temp[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

the code gives the correct outcome for parallel_iterations > 1. Can someone explain what is going on here and suggest an efficient way to update the variable? The original variable I want to update is huge, and the workaround I found is very inefficient.

You do not need to use variables for this; you can just generate the row-updated tensor in the loop body:

import tensorflow as tf

def method(i):
    # Placeholder logic
    return tf.cast(tf.range(i, i + 100), tf.float32)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    # Produce new row
    updated_row = method(i)
    # Index vector that is 1 only on the row to update
    idx = tf.equal(tf.range(tf.shape(var)[0]), i)
    idx = tf.cast(idx[:, tf.newaxis], var.dtype)
    # Compose the new tensor with the old one and the new row
    var_updated = (1 - idx) * var + idx * updated_row
    return [tf.add(i, 1), var_updated]

# Start with zeros
a = tf.zeros([100, 100], tf.float32)
i = tf.constant(0)
i_end, a_updated = tf.while_loop(condition, body, [i, a], parallel_iterations=10)

with tf.Session() as sess:
    print(sess.run(a_updated))

Output:

[[  0.   1.   2. ...  97.  98.  99.]
 [  1.   2.   3. ...  98.  99. 100.]
 [  2.   3.   4. ...  99. 100. 101.]
 ...
 [ 97.  98.  99. ... 194. 195. 196.]
 [ 98.  99. 100. ... 195. 196. 197.]
 [ 99. 100. 101. ... 196. 197. 198.]]
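
A note on efficiency: the masking approach above materializes two full-size products of the [100, 100] tensor on every iteration. Since you say the real variable is huge, here is a sketch of an alternative (assuming a TensorFlow version that provides tf.tensor_scatter_nd_update, i.e. TF 1.15+ / 2.x) that scatters only the changed row:

import tensorflow as tf

def method(i):
    # Placeholder logic, as above
    return tf.cast(tf.range(i, i + 100), tf.float32)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    updated_row = method(i)
    # Write the new row at position i: indices has shape [1, 1],
    # updates has shape [1, 100]
    var_updated = tf.tensor_scatter_nd_update(
        var,
        indices=tf.reshape(i, [1, 1]),
        updates=updated_row[tf.newaxis])
    return [tf.add(i, 1), var_updated]

a = tf.zeros([100, 100], tf.float32)
i = tf.constant(0)
i_end, a_updated = tf.while_loop(condition, body, [i, a], parallel_iterations=10)

with tf.Session() as sess:
    print(sess.run(a_updated))

Like the masked version, each iteration's output depends only on the loop variables, so parallel_iterations > 1 remains safe.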

In the tf.function documentation I've found the following:

Key Point: Any Python side-effects (appending to a list, printing with print, etc) will only happen once, when func is traced. To have side-effects executed into your tf.function they need to be written as TF ops:

I'm pretty sure that's what's going on here. You're expecting a to change, but that is a "side effect" (https://runestone.academy/runestone/books/published/fopp/Functions/SideEffects.html), which TensorFlow does not fully support. When you change a to temp, you're no longer relying on the side effect and the code works.
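
To see the quoted behavior in isolation, here is a minimal sketch (assuming TensorFlow 2.x, where tf.function is available): the Python print fires only while the function is being traced, while the tf.print op runs on every call.

import tensorflow as tf

@tf.function
def f(x):
    print("traced")       # Python side effect: happens once, at tracing time
    tf.print("executed")  # TF op: happens on every call
    return x + 1

f(tf.constant(1))  # prints "traced" and "executed"
f(tf.constant(2))  # prints only "executed"

The tf.assign in the question does run as a TF op, but mutating a instead of threading the result through the loop variables leaves its ordering relative to the parallel iterations unspecified, which is why parallel_iterations > 1 breaks.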
