繁体   English   中英

Tensorflow 2.0 中的 tf.function 和 tf.while 循环

[英]tf.function and tf.while loop in Tensorflow 2.0

我正在尝试使用tf.while_loop并行化循环。 正如这里所建议的, parallel_iterations参数在 Eager 模式中没有区别。 所以我试图换tf.while_looptf.function 但是,添加装饰器后,迭代变量的行为发生了变化。

例如,这段代码有效。

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)
tf.while_loop(c, print_fun, [iteration])

如果我添加装饰器,就会出现错误。

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    tf.while_loop(c, print_fun, [iteration])

run_graph()

从我的调试过程中,我发现变量iteration从张量变为占位符。 这是为什么? 我应该如何修改代码以消除错误?

谢谢。

您的第一个代码段(没有@tf.function的代码段)利用 TensorFlow 2 的急切执行直接操作一个 numpy 数组(即您的外部iteration对象)。 使用@tf.function ,这不起作用,因为 @tf.function 尝试将您的代码编译成 tf.Graph,它不能直接对 numpy 数组进行操作(它只能处理 tensorflow 张量)。 要解决此问题,请使用 tf.Variable 并继续为其切片分配值。

使用@tf.function ,通过利用@tf.function的自动 Python 到图形转换功能(称为 AutoGraph),您实际上可以使用更简单的代码来实现。 您只需编写一个普通的 Python while 循环(使用tf.less()代替<运算符),而 while 循环将被 AutoGraph 编译成 tf.while_loop 底层。

代码看起来像:

result = tf.Variable(np.zeros([10], dtype=np.int32))

@tf.function
def run_graph():
  i = tf.constant(0, dtype=tf.int32)
  while tf.less(i, 10):
    result[i].assign(i)  # Performance may require tuning here.
    i += 1

run_graph()
print(result.read_value())

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM