簡體   English   中英

tf.GradientTape() 位置對模型訓練時間的影響

[英]Effect of the position of tf.GradientTape() in model training time

我試圖更新每個時期的權重,但我正在批量處理數據。 問題是,為了規范損失,我需要在訓練循環之外錄制 TensorFlow 變量(要跟蹤和規范化)。 但是當我這樣做時,訓練時間是巨大的。

我認為,它將所有批次的變量累積到圖中並在最后計算梯度。

我已經開始跟蹤 for 循環外和 for 循環內的變量,后者比第一次要快。 我很困惑為什么會發生這種情況,因為無論我做什么,我的模型的可訓練變量和損失都保持不變。

# Very Slow

loss_value = 0
batches = 0

with tf.GradientTape() as tape:
    for inputs, min_seq in zip(dataset, minutes_sequence):
        temp_loss_value = my_loss_function(inputs, min_seq)
        batches +=1
        loss_value = loss_value + temp_loss_value

# The following line takes huge time.
grads = tape.gradient(loss_value, model.trainable_variables)

# Very Fast

loss_value = 0
batches = 0

for inputs, min_seq in zip(dataset, minutes_sequence):
    with tf.GradientTape() as tape:
        temp_loss_value = my_loss_function(inputs, min_seq)
        batches +=1
        loss_value = loss_value + temp_loss_value

# If I do the following line, the graph will break because this are out of tape's scope.
    loss_value = loss_value / batches

# the following line takes huge time
grads = tape.gradient(loss_value, model.trainable_variables)

當我在 for 循環內部聲明 tf.GradientTape() 時,它非常快但我在外面它很慢。

PS - 這是針對自定義損失的,該架構僅包含一個大小為 10 的隱藏層。

我想知道,tf.GradientTape() 位置的不同之處以及它應該如何用於批處理數據集中每個時期的權重更新。

磁帶變量主要用於觀察可訓練的張量變量(記錄變量的先前值和變化值),以便我們可以根據損失函數計算訓練時期的梯度。 它是這里用來記錄變量狀態的 python 上下文管理器構造的實現。 關於 python 上下文管理器的優秀資源在這里 因此,如果在循環內部,它將記錄該前向傳遞的變量(權重),以便我們可以一次計算所有這些變量的梯度(而不是像在沒有像 tensorflow 這樣的庫的幼稚實現中那樣基於堆棧的梯度傳遞) . 如果它在循環之外,它將記錄所有時期的狀態,並且根據 Tensorflow 源代碼,如果使用 TF2.0,它也會刷新,這與模型開發人員必須處理刷新的 TF1.x 不同。 在您的示例中,您沒有設置任何編寫器,但如果設置了任何編寫器,它也會這樣做。 因此,對於上面的代碼,它將繼續記錄(內部使用 Graph.add_to_collection 方法)所有權重,隨着 epochs 的增加,您應該會看到減速。 減速率將與網絡的大小(可訓練變量)和當前紀元數成正比。

所以把它放在循環里面是正確的。 此外,梯度應該應用在 for 循環內部而不是外部(與 with 相同的縮進級別),否則您僅在訓練循環結束時(在最后一個時期之后)應用梯度。 我看到您的訓練對於梯度檢索的當前位置可能不是那么好(之后它被應用到您的代碼中,盡管您在代碼段中省略了它)。

我剛剛找到的關於gradienttape 的另一種好資源

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM