[英]Use dictionary in tf.function input_signature in Tensorflow 2.0
I am using Tensorflow 2.0 and facing the following situation:我正在使用 Tensorflow 2.0 并面临以下情况:
@tf.function
def my_fn(items):
.... #do stuff
return
If items is a dict of Tensors like for example:如果 items 是张量的字典,例如:
item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}
Is there a way of using input_signature argument of tf.function so I can force tf2 to avoid creating multiple graphs when item1 is for example tf.zeros([2,1])
?有没有办法使用 tf.function 的 input_signature 参数,以便当 item1 是例如
tf.zeros([2,1])
时,我可以强制 tf2 避免创建多个图形?
The input signature has to be a list, but elements in the list can be dictionaries or lists of Tensor Specs.输入签名必须是一个列表,但列表中的元素可以是字典或张量规范列表。 In your case I would try: (the
name
attributes are optional)在你的情况下,我会尝试:(
name
属性是可选的)
signature_dict = { "item1": tf.TensorSpec(shape=[2], dtype=tf.int32, name="item1"),
"item2": tf.TensorSpec(shape=[], dtype=tf.int32, name="item2") }
# don't forget the brackets around the 'signature_dict'
@tf.function(input_signature = [signature_dict])
def my_fn(items):
.... # do stuff
return
# calling the TensorFlow function
my_fun(items)
However, if you want to call a particular concrete function created by my_fn
, you have to unpack the dictionary.但是,如果要调用由
my_fn
创建的特定具体函数,则必须解压缩字典。 You also have to provide the name
attribute in tf.TensorSpec
.您还必须在
tf.TensorSpec
提供name
属性。
# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)
This is annoying but should be resolved in TensorFlow 2.3.这很烦人,但应该在 TensorFlow 2.3 中解决。 (see the end of the TF Guide to 'Concrete functions')
(参见“具体函数”TF 指南的结尾)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.