繁体   English   中英

TensorFlow 2 如何在 tf.function 中使用 *args?

[英]TensorFlow 2 How to use *args in tf.function?

更新:

进行了更多测试,但我无法通过以下方式重现该行为:

import tensorflow as tf
import numpy as np

@tf.function
def tf_being_unpythonic(an_input, another_input):
    return an_input + another_input

@tf.function
def example(*inputs, other_args = True):
    return tf_being_unpythonic(*inputs)

class TestClass(tf.keras.Model):
    def __init__(self, a, b):
        super().__init__()
        self.a= a
        self.b = b

    @tf.function
    def call(self, *inps, some_kwarg=False):
        if some_kwarg:
            return self.a(*inps)
        return self.b(*inps)

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.inps = tf.keras.layers.Flatten()
        self.hl1 = tf.keras.layers.Dense(5)
        self.hl2 = tf.keras.layers.Dense(4)
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self,observation):
        x = self.inps(observation)
        x = self.hl1(x)
        x = self.hl2(x)
        return self.out(x)


class Model2(Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()

    @tf.function
    def call(self,b,c):
        x = self.prein([b,c])
        return super().call(x)   

am = Model()
pm = Model2()
test = TestClass(am,pm)

a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))

test(a,some_kwarg=True)
test(a,b) 

所以这可能是其他地方的错误。

@tf.function
def call(self, *inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    if target:
        return self._target_network(*inp, training)
    return self._network(*inp, training)

我得到:

ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []

但是 print(inp) 给出:

(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,) 

我已经编辑过并且只是未提交的玩具代码,因此无法进一步调查。 将问题留在这里,以便没有遇到此问题的每个人都没有可阅读的内容。

我不认为使用*args构造是tf.function的好习惯。 如您所见,大多数接受可变数量输入的 TF 函数都使用元组。

因此,您可以将函数签名重写为:

def call(self, inputs, target=False, training=False)

并调用它:

instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])

编辑

顺便说一句,我在使用带有*args构造的tf.function时没有看到任何错误:

import tensorflow as tf

@tf.function
def call(*inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    return inp[0]

def main():
    print(call(1))
    print(call(2, 2))
    print(call(3, 3, 3))


if __name__ == '__main__':
    main()
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

因此,您应该向我们提供有关您尝试执行的操作以及错误所在的更多信息。

这可能是最近解决的错误。 *args**kwargs应该可以正常工作。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM