[英]TensorFlow 2 How to use *args in tf.function?
进行了更多测试,但我无法通过以下方式重现该行为:
import tensorflow as tf
import numpy as np
@tf.function
def tf_being_unpythonic(an_input, another_input):
return an_input + another_input
@tf.function
def example(*inputs, other_args = True):
return tf_being_unpythonic(*inputs)
class TestClass(tf.keras.Model):
def __init__(self, a, b):
super().__init__()
self.a= a
self.b = b
@tf.function
def call(self, *inps, some_kwarg=False):
if some_kwarg:
return self.a(*inps)
return self.b(*inps)
class Model(tf.keras.Model):
def __init__(self):
super().__init__()
self.inps = tf.keras.layers.Flatten()
self.hl1 = tf.keras.layers.Dense(5)
self.hl2 = tf.keras.layers.Dense(4)
self.out = tf.keras.layers.Dense(1)
@tf.function
def call(self,observation):
x = self.inps(observation)
x = self.hl1(x)
x = self.hl2(x)
return self.out(x)
class Model2(Model):
def __init__(self):
super().__init__()
self.prein = tf.keras.layers.Concatenate()
@tf.function
def call(self,b,c):
x = self.prein([b,c])
return super().call(x)
am = Model()
pm = Model2()
test = TestClass(am,pm)
a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))
test(a,some_kwarg=True)
test(a,b)
所以这可能是其他地方的错误。
@tf.function
def call(self, *inp, target=False, training=False):
if not len(inp):
raise ValueError("Call requires some input")
if target:
return self._target_network(*inp, training)
return self._network(*inp, training)
我得到:
ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []
但是 print(inp) 给出:
(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,)
我已经编辑过并且只是未提交的玩具代码,因此无法进一步调查。 将问题留在这里,以便没有遇到此问题的每个人都没有可阅读的内容。
我不认为使用*args
构造是tf.function
的好习惯。 如您所见,大多数接受可变数量输入的 TF 函数都使用元组。
因此,您可以将函数签名重写为:
def call(self, inputs, target=False, training=False)
并调用它:
instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])
顺便说一句,我在使用带有*args
构造的tf.function
时没有看到任何错误:
import tensorflow as tf
@tf.function
def call(*inp, target=False, training=False):
if not len(inp):
raise ValueError("Call requires some input")
return inp[0]
def main():
print(call(1))
print(call(2, 2))
print(call(3, 3, 3))
if __name__ == '__main__':
main()
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
因此,您应该向我们提供有关您尝试执行的操作以及错误所在的更多信息。
这可能是最近解决的错误。 *args
和**kwargs
应该可以正常工作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.