繁体   English   中英

tf.while_loop 在并行运行时给出错误的结果

[英]tf.while_loop gives wrong result when it runs in parallel

我想在tf.while_loop中逐行更新tf.while_loop中的二维tf.variable 为此,我使用tf.assign方法。 问题是我的实现和parallel_iterations>1结果是错误的。 使用parallel_iterations=1结果是正确的。 代码是这样的:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
    temp = tf.assign(a[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

迭代是完全独立的,我不知道是什么问题。

奇怪的是,如果我像这样更改代码:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    zeros = lambda: tf.zeros([100, 100], dtype=tf.int64)
    temp = tf.Variable(initial_value=zeros, dtype=tf.int64)
    updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
    temp = tf.assign(temp[i], updated_row)
    return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

代码给出了parallel_iterations>1的正确结果。 有人可以向我解释这里发生了什么,并给我一个有效的解决方案来更新变量,因为我要更新的原始变量很大,而我发现的解决方案效率很低。

您不需要为此使用变量,您只需在循环体上生成行更新张量:

import tensorflow as tf

def method(i):
    # Placeholder logic
    return tf.cast(tf.range(i, i + 100), tf.float32)

def condition(i, var):
    return tf.less(i, 100)

def body(i, var):
    # Produce new row
    updated_row = method(i)
    # Index vector that is 1 only on the row to update
    idx = tf.equal(tf.range(tf.shape(a)[0]), i)
    idx = tf.cast(idx[:, tf.newaxis], var.dtype)
    # Compose the new tensor with the old one and the new row
    var_updated = (1 - idx) * var + idx * updated_row
    return [tf.add(i, 1), var_updated]

# Start with zeros
a = tf.zeros([100, 100], tf.float32)
i = tf.constant(0)
i_end, a_updated = tf.while_loop(condition, body, [i, a], parallel_iterations=10)

with tf.Session() as sess:
    print(sess.run(a_updated))

输出:

[[  0.   1.   2. ...  97.  98.  99.]
 [  1.   2.   3. ...  98.  99. 100.]
 [  2.   3.   4. ...  99. 100. 101.]
 ...
 [ 97.  98.  99. ... 194. 195. 196.]
 [ 98.  99. 100. ... 195. 196. 197.]
 [ 99. 100. 101. ... 196. 197. 198.]]

在 tf.function 我发现了以下内容:

关键点:任何 Python 副作用(附加到列表、打印打印等)只会在跟踪 func 时发生一次。 要将副作用执行到您的 tf.function 中,它们需要编写为 TF ops:

我很确定这就是这里发生的事情。 你期待一个改变,但是这是一个“副作用”( https://runestone.academy/runestone/books/published/fopp/Functions/SideEffects.html )的tensorflow不完全支持。 当您将 a 更改为 temp 时,您不再依赖副作用并且代码有效。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM