簡體   English   中英

Tensorflow 2.0 中的 tf.function 和 tf.while 循環

[英]tf.function and tf.while loop in Tensorflow 2.0

我正在嘗試使用tf.while_loop並行化循環。 正如這里所建議的, parallel_iterations參數在 Eager 模式中沒有區別。 所以我試圖換tf.while_looptf.function 但是,添加裝飾器后,迭代變量的行為發生了變化。

例如,這段代碼有效。

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)
tf.while_loop(c, print_fun, [iteration])

如果我添加裝飾器,就會出現錯誤。

result = np.zeros(10)
iteration = tf.constant(0)
c = lambda i: tf.less(i, 10)
def print_fun(iteration):
    result[iteration] = iteration
    iteration += 1
    return (iteration,)

@tf.function
def run_graph():
    iteration = tf.constant(0)
    tf.while_loop(c, print_fun, [iteration])

run_graph()

從我的調試過程中,我發現變量iteration從張量變為占位符。 這是為什么? 我應該如何修改代碼以消除錯誤?

謝謝。

您的第一個代碼段(沒有@tf.function的代碼段)利用 TensorFlow 2 的急切執行直接操作一個 numpy 數組(即您的外部iteration對象)。 使用@tf.function ,這不起作用,因為 @tf.function 嘗試將您的代碼編譯成 tf.Graph,它不能直接對 numpy 數組進行操作(它只能處理 tensorflow 張量)。 要解決此問題,請使用 tf.Variable 並繼續為其切片分配值。

使用@tf.function ,通過利用@tf.function的自動 Python 到圖形轉換功能(稱為 AutoGraph),您實際上可以使用更簡單的代碼來實現。 您只需編寫一個普通的 Python while 循環(使用tf.less()代替<運算符),而 while 循環將被 AutoGraph 編譯成 tf.while_loop 底層。

代碼看起來像:

result = tf.Variable(np.zeros([10], dtype=np.int32))

@tf.function
def run_graph():
  i = tf.constant(0, dtype=tf.int32)
  while tf.less(i, 10):
    result[i].assign(i)  # Performance may require tuning here.
    i += 1

run_graph()
print(result.read_value())

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM