簡體   English   中英

Tensorflow 2 tf.function input_signature 用於列表輸入

[英]Tensorflow 2 tf.function input_signature for a list input

為了使用 saved_model api 導出我的saved_model ,我需要定義要在加載后調用的每個方法的input_signature 我不知道如何判斷輸入是一個可變長度的列表(例如tf.keras.Model.call )。

SO上有一個關於input_signature的未回答問題列表:

還有這個關於*args的: TensorFlow 2 如何在 tf.function 中使用 *args? 但它不處理saved_model的問題。

也許您可以使用張量而不是列表作為輸入?

然后在tf.TensorSpec中指定一個[None]維度以允許跟蹤重用的靈活性。

由於 TensorFlow 根據張量的形狀匹配張量,因此使用None維度作為通配符將允許函數重用跟蹤以用於可變大小的輸入。 如果您有不同長度的序列,或者每個批次有不同大小的圖像,則可能會出現可變大小的輸入。

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM