![](/img/trans.png)
[英]Use dictionary in tf.function input_signature in Tensorflow 2.0
[英]Tensorflow 2 tf.function input_signature for a list input
為了使用 saved_model api 導出我的saved_model
,我需要定義要在加載后調用的每個方法的input_signature
。 我不知道如何判斷輸入是一個可變長度的列表(例如tf.keras.Model.call
)。
SO上有一個關於input_signature
的未回答問題列表:
還有這個關於*args
的: TensorFlow 2 如何在 tf.function 中使用 *args? 但它不處理saved_model
的問題。
也許您可以使用張量而不是列表作為輸入?
然后在tf.TensorSpec
中指定一個[None]
維度以允許跟蹤重用的靈活性。
由於 TensorFlow 根據張量的形狀匹配張量,因此使用None
維度作為通配符將允許函數重用跟蹤以用於可變大小的輸入。 如果您有不同長度的序列,或者每個批次有不同大小的圖像,則可能會出現可變大小的輸入。
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.