[英]tf.function and tf.while loop in Tensorflow 2.0
I am trying to parallelize loop using tf.while_loop
.我正在尝试使用
tf.while_loop
并行化循环。 As suggested here , the parallel_iterations
argument doesn't make a difference in the eager mode.正如这里所建议的,
parallel_iterations
参数在 Eager 模式中没有区别。 So I attempted to wrap tf.while_loop
with tf.function
.所以我试图换
tf.while_loop
与tf.function
。 However, after adding the decorator,the behavior of the iteration variable changes.但是,添加装饰器后,迭代变量的行为发生了变化。
For example, this piece of code works.例如,这段代码有效。
result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
result[iteration] = iteration
iteration += 1
return (iteration,)
tf.while_loop(c, print_fun, [iteration])
If I add the decorator, bug occurs.如果我添加装饰器,就会出现错误。
result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
result[iteration] = iteration
iteration += 1
return (iteration,)
@tf.function
def run_graph():
iteration = tf.constant(0)
tf.while_loop(c, print_fun, [iteration])
run_graph()
From my debugging process, I found that variable iteration
changes from a tensor to a placeholder.从我的调试过程中,我发现变量
iteration
从张量变为占位符。 Why is that?这是为什么? How should I modify the code to eliminate the bug?
我应该如何修改代码以消除错误?
Thanks.谢谢。
The code in your first snippet (the one without the @tf.function
) takes advantage of TensorFlow 2's eager execution to manipulate a numpy array (ie, your outer iteration
object) directly.您的第一个代码段(没有
@tf.function
的代码段)利用 TensorFlow 2 的急切执行直接操作一个 numpy 数组(即您的外部iteration
对象)。 With @tf.function
, this doesn't work because @tf.function tries to compile your code into a tf.Graph, which cannot operate on a numpy array directly (it can only process tensorflow tensors).使用
@tf.function
,这不起作用,因为 @tf.function 尝试将您的代码编译成 tf.Graph,它不能直接对 numpy 数组进行操作(它只能处理 tensorflow 张量)。 To get around this issue, use a tf.Variable and keep assigning value into its slices.要解决此问题,请使用 tf.Variable 并继续为其切片分配值。
With @tf.function
, what you are trying to do is actually achievable with simpler code, by taking advantage of @tf.function
's automatic Python-to-graph transformation feature (known as AutoGraph).使用
@tf.function
,通过利用@tf.function
的自动 Python 到图形转换功能(称为 AutoGraph),您实际上可以使用更简单的代码来实现。 You just write a normal Python while loop (using tf.less()
in lieu of the <
operator), and the while loop will be compiled by AutoGraph into a tf.while_loop under the hood.您只需编写一个普通的 Python while 循环(使用
tf.less()
代替<
运算符),而 while 循环将被 AutoGraph 编译成 tf.while_loop 底层。
The code looks something like:代码看起来像:
result = tf.Variable(np.zeros([10], dtype=np.int32))
@tf.function
def run_graph():
i = tf.constant(0, dtype=tf.int32)
while tf.less(i, 10):
result[i].assign(i) # Performance may require tuning here.
i += 1
run_graph()
print(result.read_value())
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.