[英]TensorFlow 2 How to use *args in tf.function?
Did a bit more testing and I can't reproduce the behaviour with:进行了更多测试,但我无法通过以下方式重现该行为:
import tensorflow as tf
import numpy as np
@tf.function
def tf_being_unpythonic(an_input, another_input):
return an_input + another_input
@tf.function
def example(*inputs, other_args = True):
return tf_being_unpythonic(*inputs)
class TestClass(tf.keras.Model):
def __init__(self, a, b):
super().__init__()
self.a= a
self.b = b
@tf.function
def call(self, *inps, some_kwarg=False):
if some_kwarg:
return self.a(*inps)
return self.b(*inps)
class Model(tf.keras.Model):
def __init__(self):
super().__init__()
self.inps = tf.keras.layers.Flatten()
self.hl1 = tf.keras.layers.Dense(5)
self.hl2 = tf.keras.layers.Dense(4)
self.out = tf.keras.layers.Dense(1)
@tf.function
def call(self,observation):
x = self.inps(observation)
x = self.hl1(x)
x = self.hl2(x)
return self.out(x)
class Model2(Model):
def __init__(self):
super().__init__()
self.prein = tf.keras.layers.Concatenate()
@tf.function
def call(self,b,c):
x = self.prein([b,c])
return super().call(x)
am = Model()
pm = Model2()
test = TestClass(am,pm)
a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))
test(a,some_kwarg=True)
test(a,b)
So it's probably a bug somewhere else.所以这可能是其他地方的错误。
@tf.function
def call(self, *inp, target=False, training=False):
if not len(inp):
raise ValueError("Call requires some input")
if target:
return self._target_network(*inp, training)
return self._network(*inp, training)
I get:我得到:
ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []
But print(inp) gives:但是 print(inp) 给出:
(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,)
I've since edited and was just uncommited toy code so can't investigate further.我已经编辑过并且只是未提交的玩具代码,因此无法进一步调查。 Will leave the question here so that everyone who doesn't get this issue won't have something to read.
将问题留在这里,以便没有遇到此问题的每个人都没有可阅读的内容。
I don't think that using a *args
construct is a good practice for a tf.function
.我不认为使用
*args
构造是tf.function
的好习惯。 As you can see, most of the TF functions accepting a variable number of inputs use a tuple.如您所见,大多数接受可变数量输入的 TF 函数都使用元组。
So, you can rewrite your function signature as:因此,您可以将函数签名重写为:
def call(self, inputs, target=False, training=False)
and calling it with:并调用它:
instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])
By the way, I don't see any error while using tf.function
with a *args
construct:顺便说一句,我在使用带有
*args
构造的tf.function
时没有看到任何错误:
import tensorflow as tf
@tf.function
def call(*inp, target=False, training=False):
if not len(inp):
raise ValueError("Call requires some input")
return inp[0]
def main():
print(call(1))
print(call(2, 2))
print(call(3, 3, 3))
if __name__ == '__main__':
main()
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
So you should provide us more informations about what you try to do and where the error is.因此,您应该向我们提供有关您尝试执行的操作以及错误所在的更多信息。
This may have been a bug that was resolved recently.这可能是最近解决的错误。
*args
and **kwargs
should work fine. *args
和**kwargs
应该可以正常工作。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.