[英]Tensorflow tf.function conditionals
我花了一些时间来确定问题。 这是:
class ComplicatedStuff:
def __init__(self):
self.result = None
def fun(self, val):
self.result = val
@tf.function
def no_fun(x, blabla):
s = ComplicatedStuff()
# s.do_this(blabla)
# s.do_that(blabla)
if x > .5:
s.fun(2*x)
else:
s.fun(x)
return s.result
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
我真的希望得到2.0
而不是1.0
。 我发现原因是条件在两个分支中都被跟踪,并且因为我使用s
中的副作用返回了一个值,所以只有第二个分支的结果存在。 问题是,我如何绕过这个限制进行编码? 使用返回值可以解决它,但它肯定会使代码变得丑陋,因为 ComplicatedStuff 包装了一堆我不想像那样公开的中间结果。 有更好的选择吗?
我想出的或多或少保留结构的东西是这个hackery:
class ComplicatedStuff(dict):
def __init__(self):
super().__init__()
self.result = None
def fun(self, val):
self.result = val
def __setattr__(self, item, value):
self[item] = value
def __getattribute__(self, item):
if item.startswith("__") or item not in self:
return super().__getattribute__(item)
else:
return self[item]
@tf.function
def no_fun(x, blabla):
s = ComplicatedStuff()
# s.do_this(blabla)
# s.do_that(blabla)
if x > .5:
s.fun(2*x)
s = s
else:
s.fun(x)
s = s
return s.result
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
应该有更好的选择吧?
Tensorflow 自动将一些if
语句转换为tf.function
中的tf.cond
节点。 这称为签名。
由于这在这里不起作用,我们可以自己做:
class ComplicatedStuff:
def __init__(self):
self.result = None
def fun(self, val):
self.result = val
@tf.function
def no_fun(x, blabla):
s = ComplicatedStuff()
# s.do_this(blabla)
# s.do_that(blabla)
x_tmp = tf.cond(x > .5, lambda: 2*x, lambda: x)
s.fun(x_tmp)
return s.result
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
no_fun(tf.constant(0.23), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=0.23>
如文档中所述:
tf.cond跟踪条件的两个分支并将其添加到图中,在执行时动态选择一个分支。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.