简体   繁体   中英

Difference between Keras' BatchNormalization and PyTorch's BatchNorm2d?

I've a sample tiny CNN implemented in both Keras and PyTorch. When I print summary of both the networks, the total number of trainable parameters are same but total number of parameters and number of parameters for Batch Normalization don't match.

Here is the CNN implementation in Keras:

inputs = Input(shape = (64, 64, 1)). # Channel Last: (NHWC)

model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation = "relu")(model)
head_root = Dense(10, activation = 'softmax')(dense)

And the summary printed for above model is:

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 131072)            0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               13107300  
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1010      
=================================================================
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64
_________________________________________________________________

Here's the implementation of the same model architecture in PyTorch:

# Image format: Channel first (NCHW) in PyTorch
class CustomModel(nn.Module):
def __init__(self):
    super(CustomModel, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
        nn.ReLU(True),
        nn.BatchNorm2d(num_features=32),
    )
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=131072, out_features=100)
    self.fc2 = nn.Linear(in_features=100, out_features=10)

def forward(self, x):
    output = self.layer1(x)
    output = self.flatten(output)
    output = self.fc1(output)
    output = self.fc2(output)
    return output

And following is the output of summary of the above model:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
================================================================
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02
----------------------------------------------------------------

As you can see in above results, Batch Normalization in Keras has more number of parameters than PyTorch (2x to be exact). So what's the difference in above CNN architectures? If they are equivalent, then what am I missing here?

Keras treats as parameters (weights) many things that will be "saved/loaded" in the layer.

While both implementations naturally have the accumulated "mean" and "variance" of the batches, these values are not trainable with backpropagation.

Nevertheless, these values are updated every batch, and Keras treats them as non-trainable weights, while PyTorch simply hides them. The term "non-trainable" here means "not trainable by backpropagation ", but doesn't mean the values are frozen.

In total they are 4 groups of "weights" for a BatchNormalization layer. Considering the selected axis (default = -1, size=32 for your layer)

  • scale (32) - trainable
  • offset (32) - trainable
  • accumulated means (32) - non-trainable, but updated every batch
  • accumulated std (32) - non-trainable, but updated every batch

The advantage of having it like this in Keras is that when you save the layer, you also save the mean and variance values the same way you save all other weights in the layer automatically. And when you load the layer, these weights are loaded together.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM