[英]Tensorflow 2 tf.function input_signature for a list input
In order to export my model with the saved_model
api, I need to define the input_signature
of each method intended to be called after loading.为了使用 saved_model api 导出我的
saved_model
,我需要定义要在加载后调用的每个方法的input_signature
。 I don't know how to tell that the input is a list with variable length (as it is for tf.keras.Model.call
for instance).我不知道如何判断输入是一个可变长度的列表(例如
tf.keras.Model.call
)。
There is a list of unanswered questions about input_signature
on SO: SO上有一个关于
input_signature
的未回答问题列表:
and also this one about *args
: TensorFlow 2 How to use *args in tf.function?还有这个关于
*args
的: TensorFlow 2 如何在 tf.function 中使用 *args? but it does not handle the problem of saved_model
.但它不处理
saved_model
的问题。
Maybe you could use a Tensor instead of a list as input?也许您可以使用张量而不是列表作为输入?
Then specify a [None]
dimension in tf.TensorSpec
to allow for flexibility in trace reuse.然后在
tf.TensorSpec
中指定一个[None]
维度以允许跟踪重用的灵活性。
Since TensorFlow matches tensors based on their shape, using a None
dimension as a wildcard will allow Functions to reuse traces for variably-sized input.由于 TensorFlow 根据张量的形状匹配张量,因此使用
None
维度作为通配符将允许函数重用跟踪以用于可变大小的输入。 Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch.如果您有不同长度的序列,或者每个批次有不同大小的图像,则可能会出现可变大小的输入。
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.