簡體   English   中英

更大的batch size會導致更大的損失

[英]Larger batch size cause larger loss

我正在嘗試使用 pytorch 解決回歸問題。 我有一個預先訓練的 model 開始。 當我調整超參數時,我發現我的批量大小和訓練/驗證損失有一個奇怪的相關性。 具體來說:

batch size = 16 -\> train/val loss around 0.6 (for epoch 1)
batch size = 64 -\> train/val loss around 0.8 (for epoch 1)
batch size = 128 -\> train/val loss around 1 (for epoch 1)

我想知道這是否正常,或者我的代碼有問題。

優化器:SGD,學習率為 1e-3

損耗 function:

def rmse(pred, real):
    residuals = pred - real
    square = torch.square(residuals)
    sum_of_square = torch.sum(square)
    mean = sum_of_square / pred.shape[0]
    root = torch.sqrt(mean)
    return root

火車循環:

def train_loop(dataloader, model, optimizer, epoch):
    num_of_batches = len(dataloader)
    total_loss = 0
    for batch, (X, y) in enumerate(dataloader):
        optimizer.zero_grad()
        
        pred = model(X)
        loss = rmse(pred, y)

        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

        #lr_scheduler.step(epoch*num_of_batches+batch)
        #last_lr = lr_scheduler.get_last_lr()[0]

    train_loss = total_loss / num_of_batches
    return train_loss

測試循環:

def test_loop(dataloader, model):
    size = len(dataloader.dataset)
    num_of_batches = len(dataloader)
    test_loss = 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += rmse(pred, y).item()

    test_loss /= num_of_batches
    return test_loss

除非您設置手動 rng 種子,否則第一個 epoch 的第一批在運行之間總是會非常不一致。 您的損失是由於您的隨機初始化權重與您的隨機子樣本批次的訓練項目的效果如何。 換句話說,無論批量大小如何,它在第一次復飛時的損失是沒有意義的(在這種情況下)。

我將從a開始。 類比, b. 深入數學,然后是 c。 以數值實驗結束。

a.)您所看到的現象與隨機梯度下降和批量梯度下降之間的差異大致相同。 在模擬情況下,學習參數應該移動的“真實”梯度或方向可以最大限度地減少整個訓練數據集的損失。 在隨機梯度下降中,梯度將學習參數向最小化單個示例損失的方向移動。 隨着批次的大小從 1 增加到整個數據集的大小,從小批量估計的梯度變得更接近整個數據集的梯度。

現在,考慮到整個數據集不精確,隨機梯度下降是否有用? 絕對地。 事實上,這個估計中的噪聲對於優化中的 escaping 局部最小值很有用。 類似地,您對整個數據集的損失估計中的任何噪聲都可能無需擔心。

b.)但是接下來讓我們看看為什么會發生這種行為。 RMSE 定義為: 在此處輸入圖像描述

其中N是數據集中示例的總數。 如果以這種方式計算 RMSE,我們預計該值大致相同(並且隨着N變大接近完全相同的值)。 但是,在您的情況下,您實際上將平均歷元損失計算為:

在此處輸入圖像描述

其中B是每個 epoch 的 minibatch 數, b是每個 minibatch 的示例數:

在此處輸入圖像描述

因此,epoch loss 是每個 minibatch 的平均 RMSE。 重新排列,我們可以看到:

在此處輸入圖像描述

B很大 ( B = N ) 並且小批量大小為 1 時,

在此處輸入圖像描述

它顯然具有與上面定義的 RMSE 完全不同的屬性。 但是,隨着B變小B = 1 ,並且小批量大小為N

在此處輸入圖像描述

這正好等於上面的 RMSE。 因此,當您增加批量大小時,您計算的數量的預期值會在這兩個表達式之間移動。 這解釋了使用不同 minibatch 大小的損失的(大致平方根)縮放。 Epoch loss 是對 RMSE 的估計(可以認為是 model 預測誤差的標准差)。 一個訓練目標可能是將此錯誤標准差推至零,但您對 epoch loss 的表達也可能是一個很好的代理。 這兩個數量本身就是您實際希望獲得的任何 model 性能的代理。

c。 你可以用一個瑣碎的玩具問題自己試試這個。 正態分布用作 model 錯誤的代理。

示例 1:計算整個數據集的 RMSE(大小為10000 xb

import torch
for b in [1,2,3,5,9,10,100,1000,10000,100000]:
  b_errors = []
  for i in range (10000):
    error = torch.normal(0,100,size = (1,b))
    error = error **2
    error = error.mean()
    b_errors.append(error)

RMSE = torch.sqrt(sum(b_errors)/len(b_errors))
print("Average RMSE for b = {}: {}".format(N,RMSE))

結果:

Average RMSE for b = 1: 99.94982147216797
Average RMSE for b = 2: 100.38357543945312
Average RMSE for b = 3: 100.24600982666016
Average RMSE for b = 5: 100.97154998779297
Average RMSE for b = 9: 100.06820678710938
Average RMSE for b = 10: 100.12358856201172
Average RMSE for b = 100: 99.94219970703125
Average RMSE for b = 1000: 99.97941589355469
Average RMSE for b = 10000: 100.00338745117188

示例 2:計算B = 10000 的 Epoch 損失

import torch
for b in [1,2,3,5,9,10,100,1000,10000,100000]:

b_errors = []
for i in range (10000):
    error = torch.normal(0,100,size = (1,b))
    error = error **2
    error = error.mean()
    error = torch.sqrt(error)
    b_errors.append(error)

avg = (sum(b_errors)/len(b_errors)
print("Average Epoch Loss for b = {}: {}".format(b,avg))

結果:

Average Epoch Loss for b = 1: 80.95650482177734
Average Epoch Loss for b = 2: 88.734375
Average Epoch Loss for b = 3: 92.08515930175781
Average Epoch Loss for b = 5: 95.56260681152344
Average Epoch Loss for b = 9: 97.49445343017578
Average Epoch Loss for b = 10: 97.20250701904297
Average Epoch Loss for b = 100: 99.6297607421875
Average Epoch Loss for b = 1000: 99.96969604492188
Average Epoch Loss for b = 10000: 99.99618530273438
Average Epoch Loss for b = 100000: 100.00079345703125

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM