[英]How to manipulate and return tf.Variable using a for loop over tf.data.Dataset inside function decorated with @tf.function?
I am trying to create a function containing a for loop over a TensorFlow Dataset that assigns a new value to a TensorFlow Variable in each iteration.我正在尝试创建一个包含 TensorFlow 数据集上的 for 循环的函数,该函数在每次迭代中为 TensorFlow 变量分配一个新值。 The Variable should also be returned as output of the function.变量也应该作为函数的输出返回。 With eager execution enabled, there are no issues, however, in graph mode, some unexpected things seem to happen.启用 Eager Execution 后,没有问题,但是,在图形模式下,似乎会发生一些意想不到的事情。 Consider the following simple dummy code:考虑以下简单的虚拟代码:
import tensorflow as tf
class Test(object):
def __init__(self):
self.var = tf.Variable(0, trainable=False, dtype=tf.float32)
self.increment = tf.constant(1, dtype=tf.float32)
self.dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])
@tf.function
def fn1(self):
self.var.assign(0)
for _ in tf.range(3):
self.var.assign(self.var+self.increment)
tf.print(self.var)
tf.print(self.var)
return self.var
@tf.function
def fn2(self):
self.var.assign(0)
for _ in self.dataset:
self.var.assign(self.var+self.increment)
tf.print(self.var)
tf.print(self.var)
return self.var
@tf.function
def fn3(self):
self.var.assign(0)
y = self.var
for _ in self.dataset:
self.var.assign(self.var+self.increment)
y = self.var
tf.print(y)
tf.print(y)
return y
@tf.function
def fn4(self):
var = 0.0
for _ in self.dataset:
var += 1.0
tf.print(var)
tf.print(var)
return var
test.fn1()
, test.fn3()
and test.fn4()
all return the following (desired) output: test.fn1()
、 test.fn3()
和test.fn4()
都返回以下(期望的)输出:
1
2
3
3
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
However, test.fn2()
behaves differently:但是, test.fn2()
行为不同:
1
2
3
0
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
Interestingly, after execution of test.fn2
, test.var
does seem to contain the correct value:有趣的是,在执行test.fn2
, test.var
似乎包含正确的值:
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>
I am not sure why test.fn2
fails.我不确定为什么test.fn2
失败。 Clearly, it is doing some things correctly (as test.var
contains the correct value after execution of the function), but it does not deliver the correct result.显然,它正在正确地做一些事情(因为test.var
在执行函数后包含正确的值),但它没有提供正确的结果。 Can you help me understand what causes this code to fail?你能帮我理解是什么导致这段代码失败吗?
The behavior described above occurs when using TensorFlow 2.1.0 for Python 3.6 on CentOS 7.在 CentOS 7 上将 TensorFlow 2.1.0 用于 Python 3.6 时会发生上述行为。
Running this on TensorFlow 2.1.0 reproduces your scenario.在TensorFlow 2.1.0上运行它会重现您的场景。
Which prints 1 2 3 0
for test.fn2()
, but you should also consider that when you print self.var
in test.fn3()
it will also show you self.var = 0
during the function call.这会为test.fn2()
打印1 2 3 0
,但您还应该考虑到,当您在test.fn3()
打印self.var
,它还会在函数调用期间向您显示self.var = 0
。
Modified fn3( ) :修改后的fn3( ) :
@tf.function
def fn3(self):
self.var.assign(0)
y = self.var
for _ in self.dataset:
self.var.assign(self.var+self.increment)
y = self.var
tf.print(y)
tf.print(self.var) # Inspect self.var value
tf.print(y)
return y
Output:输出:
# Executed in Tensorflow 2.1.0
# test.fn3()
1
2
3
0 << self.var
3
This is already fixed If you execute this in Tensorflow 2.2.0-rc2 .如果你在Tensorflow 2.2.0-rc2 中执行它,这已经修复了。
The output will be your desired outcome even when printing it during graph execution.即使在图形执行期间打印输出,输出也将是您想要的结果。
To quickly simulate this you could use Google Colab and use %tensorflow_version 2.x
to get the latest available version for Tensorflow .要快速模拟这个你可以使用谷歌Colab和使用%tensorflow_version 2.x
以获取Tensorflow最新版本。
Output:输出:
# Executed in Tensorflow 2.2.0-rc2
Function 1
1
2
3
3
Function 2
1
2
3
3
Function 3
1
2
3
3 << Value of self.var in test.fn3()
3
Function 4
1
2
3
3
You could check more about the changes in the latest Tensorflow Updates in this link .您可以在此链接中查看有关最新 Tensorflow 更新更改的更多信息。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.