簡體   English   中英

Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d 的區別?

[英]Difference between Keras' BatchNormalization and PyTorch's BatchNorm2d?

我有一個在 Keras 和 PyTorch 中實現的示例微型 CNN。 當我打印兩個網絡的摘要時,可訓練參數的總數相同但參數總數和批量標准化的參數數量不匹配。

這是 Keras 中的 CNN 實現:

inputs = Input(shape = (64, 64, 1)). # Channel Last: (NHWC)

model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation = "relu")(model)
head_root = Dense(10, activation = 'softmax')(dense)

為上述模型打印的摘要是:

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 131072)            0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               13107300  
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1010      
=================================================================
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64
_________________________________________________________________

以下是 PyTorch 中相同模型架構的實現:

# Image format: Channel first (NCHW) in PyTorch
class CustomModel(nn.Module):
def __init__(self):
    super(CustomModel, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
        nn.ReLU(True),
        nn.BatchNorm2d(num_features=32),
    )
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=131072, out_features=100)
    self.fc2 = nn.Linear(in_features=100, out_features=10)

def forward(self, x):
    output = self.layer1(x)
    output = self.flatten(output)
    output = self.fc1(output)
    output = self.fc2(output)
    return output

以下是上述模型的摘要輸出:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
================================================================
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02
----------------------------------------------------------------

正如您在上面的結果中看到的,Keras 中的 Batch Normalization 比 PyTorch 具有更多的參數(准確地說是 2 倍)。 那么上面的CNN架構有什么不同呢? 如果它們是等效的,那么我在這里缺少什么?

Keras 將許多將在層中“保存/加載”的內容視為參數(權重)。

雖然這兩種實現自然具有批次的累積“均值”和“方差”,但這些值無法通過反向傳播進行訓練。

盡管如此,這些值每批次都會更新,Keras 將它們視為不可訓練的權重,而 PyTorch 只是將它們隱藏起來。 這里的術語“不可訓練”是指“不可通過反向傳播訓練”,但並不意味着值被凍結。

總的來說,它們是BatchNormalization層的 4 組“權重”。 考慮選定的軸(默認 = -1,層大小 = 32)

  • scale (32) - 可訓練
  • offset (32) - 可訓練
  • accumulated means (32) - 不可訓練,但每批次更新
  • accumulated std (32) - 不可訓練,但每批次更新

在 Keras 中這樣做的好處是,當您保存圖層時,您還可以像自動保存圖層中的所有其他權重一樣保存均值和方差值。 當您加載圖層時,這些權重會一起加載。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM