[英]Tensorflow2 warning using @tffunction
此示例代碼來自 Tensorflow 2
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
with writer.as_default():
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
for step in range(100):
my_func(step)
writer.flush()
但它正在發出警告。
警告:tensorflow:最近 5 次調用中有 5 次觸發了 tf.function 回溯。 追蹤是昂貴的,過多的追蹤可能是由於傳遞了 python 個對象而不是張量。 此外,tf.function 具有 experimental_relax_shapes=True 選項,可放寬參數形狀,從而避免不必要的回溯。 詳情請參考https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args和https://www.tensorflow.org/api_docs/python/tf/function 。
有一個更好的方法嗎?
tf.function
有一些“特殊性”。 我強烈推薦閱讀這篇文章: https://www.tensorflow.org/tutorials/customization/performance
在這種情況下,問題在於每次使用不同的輸入簽名調用時 function 都會“回溯”(即構建新圖)。 對於張量,輸入簽名指的是 shape 和 dtype,但對於 Python 數字,每個新值都被解釋為“不同”。 在這種情況下,因為您調用 function 時的step
變量每次都會發生變化,因此 function 也每次都會回溯。 這對於“真實”代碼(例如在函數內部調用 model)將非常慢。
您可以通過簡單地將step
轉換為張量來修復它,在這種情況下,不同的值不會算作新的輸入簽名:
for step in range(100):
step = tf.convert_to_tensor(step, dtype=tf.int64)
my_func(step)
writer.flush()
或使用tf.range
直接獲取張量:
for step in tf.range(100):
step = tf.cast(step, tf.int64)
my_func(step)
writer.flush()
這不應該產生警告(並且速度更快)。
我使用model(x)
而不是model.predict(x)
它對我有用
如果您在自定義 function 中遇到此錯誤,請為您的 function 添加shape
和dtype
的固定簽名。
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
...
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.