繁体   English   中英

Tensorflow tf.function 条件

[英]Tensorflow tf.function conditionals

我花了一些时间来确定问题。 这是:

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()  
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
    else:
        s.fun(x)
    return s.result
    
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

我真的希望得到2.0而不是1.0 我发现原因是条件在两个分支中都被跟踪,并且因为我使用s中的副作用返回了一个值,所以只有第二个分支的结果存在。 问题是,我如何绕过这个限制进行编码? 使用返回值可以解决它,但它肯定会使代码变得丑陋,因为 ComplicatedStuff 包装了一堆我不想像那样公开的中间结果。 有更好的选择吗?

我想出的或多或少保留结构的东西是这个hackery:

class ComplicatedStuff(dict):
    def __init__(self):
        super().__init__()
        self.result = None

    def fun(self, val):
        self.result = val
        
    def __setattr__(self, item, value):
        self[item] = value
    
    def __getattribute__(self, item):
        if item.startswith("__") or item not in self:
            return super().__getattribute__(item)
        else:
            return self[item]
        
@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
        s = s
    else:
        s.fun(x)
        s = s
    return s.result
    
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

应该有更好的选择吧?

Tensorflow 自动将一些if语句转换为tf.function中的tf.cond节点。 这称为签名

由于这在这里不起作用,我们可以自己做:

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()  
    # s.do_this(blabla)
    # s.do_that(blabla)

    x_tmp = tf.cond(x > .5, lambda: 2*x, lambda: x)
    s.fun(x_tmp)
    return s.result
    
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

no_fun(tf.constant(0.23), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=0.23>

如文档中所述:

tf.cond跟踪条件的两个分支并将其添加到图中,在执行时动态选择一个分支。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM