簡體   English   中英

使用@tffunction 的 Tensorflow2 警告

[英]Tensorflow2 warning using @tffunction

此示例代碼來自 Tensorflow 2

writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()

但它正在發出警告。

警告:tensorflow:最近 5 次調用中有 5 次觸發了 tf.function 回溯。 追蹤是昂貴的,過多的追蹤可能是由於傳遞了 python 個對象而不是張量。 此外,tf.function 具有 experimental_relax_shapes=True 選項,可放寬參數形狀,從而避免不必要的回溯。 詳情請參考https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_argshttps://www.tensorflow.org/api_docs/python/tf/function

有一個更好的方法嗎?

tf.function有一些“特殊性”。 我強烈推薦閱讀這篇文章: https://www.tensorflow.org/tutorials/customization/performance

在這種情況下,問題在於每次使用不同的輸入簽名調用時 function 都會“回溯”(即構建新圖)。 對於張量,輸入簽名指的是 shape 和 dtype,但對於 Python 數字,每個新值都被解釋為“不同”。 在這種情況下,因為您調用 function 時的step變量每次都會發生變化,因此 function 也每次都會回溯。 這對於“真實”代碼(例如在函數內部調用 model)將非常慢。

您可以通過簡單地將step轉換為張量來修復它,在這種情況下,不同的值不會算作新的輸入簽名:

for step in range(100):
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()

或使用tf.range直接獲取張量:

for step in tf.range(100):
    step = tf.cast(step, tf.int64)
    my_func(step)
    writer.flush()

這不應該產生警告(並且速度更快)。

嘗試這個

我使用model(x)而不是model.predict(x)它對我有用




如果您在自定義 function 中遇到此錯誤,請為您的 function 添加shapedtype的固定簽名。

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
...

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM