簡體   English   中英

我如何在每個時期而不是每個批次中獲得損失?

[英]How do I get a loss per epoch and not per batch?

在我的理解中,epoch 是在整個數據集上任意頻繁地重復運行,然后分部分處理,即所謂的批處理。 在每次train_on_batch計算一個損失后,更新權重,下一批將獲得更好的結果。 這些損失是我對神經網絡質量和學習狀態的指標。

在幾個來源中,每個時期都會計算(並打印)損失。 因此,我不確定我這樣做是否正確。

目前我的 GAN 看起來像這樣:

for epoch:
  for batch:

    fakes = generator.predict_on_batch(batch)

    dlc = discriminator.train_on_batch(batch, ..)
    dlf = discriminator.train_on_batch(fakes, ..)
    dis_loss_total = 0.5 *  np.add(dlc, dlf)

    g_loss = gan.train_on_batch(batch,..)

    # save losses to array to work with later

這些損失是針對每個批次的。 我如何在一個時代獲得它們? 順便說一句:我需要一個時代的損失,為了什么?

沒有直接的方法來計算一個時期的損失。 實際上,一個 epoch 的損失通常定義為該 epoch 中批次損失的平均值。 因此,您可以在一個 epoch 期間累積損失值,並在最后將其除以該 epoch 中的批次數:

epoch_loss = []
for epoch in range(n_epochs):
    acc_loss = 0.
    for batch in range(n_batches):
        # do the training 
        loss = model.train_on_batch(...)
        acc_loss += loss
    epoch_loss.append(acc_loss / n_batches)

至於另一個問題,epoch loss 的一個用途可能是將其用作停止訓練的指標(但是,通常使用驗證損失,而不是訓練損失)。

我會稍微擴展@today 的回答。 在如何報告一個時期的損失以及如何使用它來確定何時應該停止訓練之間存在一定的平衡。

  • 如果您只查看最近批次的損失,那么對數據集損失的估計將是一個非常嘈雜的估計,因為該批次可能恰好存儲了您的模型遇到問題的所有樣本,或者所有成功的微不足道的樣本.
  • 如果您查看 epoch 中所有批次的平均損失,您可能會得到一個偏斜的響應,因為正如您所指出的,該模型已經(希望)在 epoch 上有所改進,因此初始批次的性能沒有那么有意義與后面批次的性能相比。

准確報告您的 epoch 損失的唯一方法是讓您的模型退出訓練模式,即修復所有模型參數,並在整個數據集上運行您的模型。 這將是對 epoch 損失的無偏計算。 然而,總的來說,這是一個糟糕的主意,因為如果您有一個復雜的模型或大量的訓練數據,您將浪費大量時間這樣做。

因此,我認為最常見的是通過報告N個小批量的平均損失來平衡這些因素,其中N大到足以消除單個批次的噪音,但又不會太大以至於模型性能在第一個和最后一批。

我知道你在使用 Keras,但這里有一個 PyTorch 示例,它清楚地說明了這個概念,復制在這里:

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

你可以看到他們累積了N = 2000 個批次的損失,報告了這 2000 個批次的平均損失,然后將運行損失歸零並繼續前進。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM