简体   繁体   English

如何在用@tf.function 装饰的函数内部使用 for 循环来操作和返回 tf.Variable tf.data.Dataset?

[英]How to manipulate and return tf.Variable using a for loop over tf.data.Dataset inside function decorated with @tf.function?

I am trying to create a function containing a for loop over a TensorFlow Dataset that assigns a new value to a TensorFlow Variable in each iteration.我正在尝试创建一个包含 TensorFlow 数据集上的 for 循环的函数,该函数在每次迭代中为 TensorFlow 变量分配一个新值。 The Variable should also be returned as output of the function.变量也应该作为函数的输出返回。 With eager execution enabled, there are no issues, however, in graph mode, some unexpected things seem to happen.启用 Eager Execution 后,没有问题,但是,在图形模式下,似乎会发生一些意想不到的事情。 Consider the following simple dummy code:考虑以下简单的虚拟代码:

import tensorflow as tf


class Test(object):
    def __init__(self):
        self.var = tf.Variable(0, trainable=False, dtype=tf.float32)
        self.increment = tf.constant(1, dtype=tf.float32)
        self.dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])

    @tf.function
    def fn1(self):
        self.var.assign(0)
        for _ in tf.range(3):
            self.var.assign(self.var+self.increment)
            tf.print(self.var)
        tf.print(self.var)
        return self.var

    @tf.function
    def fn2(self):
        self.var.assign(0)
        for _ in self.dataset:
            self.var.assign(self.var+self.increment)
            tf.print(self.var)
        tf.print(self.var)
        return self.var

    @tf.function
    def fn3(self):
        self.var.assign(0)
        y = self.var
        for _ in self.dataset:
            self.var.assign(self.var+self.increment)
            y = self.var
            tf.print(y)
        tf.print(y)
        return y

    @tf.function
    def fn4(self):
        var = 0.0
        for _ in self.dataset:
            var += 1.0
            tf.print(var)
        tf.print(var)
        return var

test.fn1() , test.fn3() and test.fn4() all return the following (desired) output: test.fn1()test.fn3()test.fn4()都返回以下(期望的)输出:

1
2
3
3
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>

However, test.fn2() behaves differently:但是, test.fn2()行为不同:

1
2
3
0
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Interestingly, after execution of test.fn2 , test.var does seem to contain the correct value:有趣的是,在执行test.fn2test.var似乎包含正确的值:

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>

I am not sure why test.fn2 fails.我不确定为什么test.fn2失败。 Clearly, it is doing some things correctly (as test.var contains the correct value after execution of the function), but it does not deliver the correct result.显然,它正在正确地做一些事情(因为test.var在执行函数后包含正确的值),但它没有提供正确的结果。 Can you help me understand what causes this code to fail?你能帮我理解是什么导致这段代码失败吗?

The behavior described above occurs when using TensorFlow 2.1.0 for Python 3.6 on CentOS 7.在 CentOS 7 上将 TensorFlow 2.1.0 用于 Python 3.6 时会发生上述行为。

Running this on TensorFlow 2.1.0 reproduces your scenario.TensorFlow 2.1.0上运行它会重现您的场景。

Which prints 1 2 3 0 for test.fn2() , but you should also consider that when you print self.var in test.fn3() it will also show you self.var = 0 during the function call.这会为test.fn2()打印1 2 3 0 ,但您还应该考虑到,当您在test.fn3()打印self.var ,它还会在函数调用期间向您显示self.var = 0

Modified fn3( ) :修改后的fn3( )

    @tf.function
    def fn3(self):
        self.var.assign(0)
        y = self.var
        for _ in self.dataset:
            self.var.assign(self.var+self.increment)
            y = self.var
            tf.print(y)
        tf.print(self.var)  # Inspect self.var value
        tf.print(y)
        return y

Output:输出:

# Executed in Tensorflow 2.1.0
# test.fn3()
1
2
3
0  << self.var
3

This is already fixed If you execute this in Tensorflow 2.2.0-rc2 .如果你在Tensorflow 2.2.0-rc2 中执行它,这已经修复了。
The output will be your desired outcome even when printing it during graph execution.即使在图形执行期间打印输出,输出也将是您想要的结果。

To quickly simulate this you could use Google Colab and use %tensorflow_version 2.x to get the latest available version for Tensorflow .要快速模拟这个你可以使用谷歌Colab和使用%tensorflow_version 2.x以获取Tensorflow最新版本

Output:输出:

# Executed in Tensorflow 2.2.0-rc2
Function 1
1
2
3
3
Function 2
1
2
3
3
Function 3
1
2
3
3 << Value of self.var in test.fn3()
3
Function 4
1
2
3
3

You could check more about the changes in the latest Tensorflow Updates in this link .您可以在此链接中查看有关最新 Tensorflow 更新更改的更多信息。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM