简体   繁体   English

Tensorflow tf.function 条件

[英]Tensorflow tf.function conditionals

It took me some time to pin down the problem.我花了一些时间来确定问题。 Here is is:这是:

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()  
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
    else:
        s.fun(x)
    return s.result
    
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

I would really expect to get 2.0 back instead of 1.0 .我真的希望得到2.0而不是1.0 I figured out the reason is that the conditional is traced in both branches, and because I return a value using a side-effect in s , only the result of the second branch survives.我发现原因是条件在两个分支中都被跟踪,并且因为我使用s中的副作用返回了一个值,所以只有第二个分支的结果存在。 The question is, how do I code around this limitation?问题是,我如何绕过这个限制进行编码? Using return values would solve it, but it will definitely uglify the code because ComplicatedStuff wraps a bunch of intermediate results that I don't want to expose like that.使用返回值可以解决它,但它肯定会使代码变得丑陋,因为 ComplicatedStuff 包装了一堆我不想像那样公开的中间结果。 Is there some better option?有更好的选择吗?

The thing I came up with that more-or-less preserved the structure, is this hackery:我想出的或多或少保留结构的东西是这个hackery:

class ComplicatedStuff(dict):
    def __init__(self):
        super().__init__()
        self.result = None

    def fun(self, val):
        self.result = val
        
    def __setattr__(self, item, value):
        self[item] = value
    
    def __getattribute__(self, item):
        if item.startswith("__") or item not in self:
            return super().__getattribute__(item)
        else:
            return self[item]
        
@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
        s = s
    else:
        s.fun(x)
        s = s
    return s.result
    
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

There must be a better option, right?应该有更好的选择吧?

Tensorflow automatically converts some if statements into tf.cond nodes in tf.function . Tensorflow 自动将一些if语句转换为tf.function中的tf.cond节点。 This is called Autograph .这称为签名

Since that is not working here, we can do that ourselves:由于这在这里不起作用,我们可以自己做:

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()  
    # s.do_this(blabla)
    # s.do_that(blabla)

    x_tmp = tf.cond(x > .5, lambda: 2*x, lambda: x)
    s.fun(x_tmp)
    return s.result
    
no_fun(tf.constant(1.), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

no_fun(tf.constant(0.23), ...)
>>> <tf.Tensor: shape=(), dtype=float32, numpy=0.23>

As mentioned in the docs:如文档中所述:

tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. tf.cond跟踪条件的两个分支并将其添加到图中,在执行时动态选择一个分支。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM