
Understanding batch_size in CNNs

Say that I have a CNN model in PyTorch and 2 inputs of the following sizes:

  • input_1: [2, 1, 28, 28]
  • input_2: [10, 1, 28, 28]

Notes:

  • To reiterate, input_1 has batch_size == 2 and input_2 has batch_size == 10.
  • input_2 is a superset of input_1. That is, input_2 contains the 2 images of input_1 in the same positions.

My question is: how does the CNN process the images in both inputs? I.e., does the CNN process every image in the batch sequentially? Or does it concatenate all of the images in the batch and then perform convolutions as usual?

The reason I ask is that:

  • The output of CNN(input_1) != CNN(input_2)[:2]

That is, the difference in batch_size results in slightly different CNN outputs for the two inputs at the same positions.

CNN is a general term for convolutional neural networks. Depending on the particular architecture, it may do different things. The main building blocks of CNNs are convolutions, which do not cause any "crosstalk" between items in a batch, and pointwise nonlinearities like ReLU, which do not either. However, most architectures also involve other operations, such as normalization layers - arguably the most popular is batch norm, which does introduce crosstalk. Many models will also use dropout, which behaves stochastically outside of eval mode (by default, models are in train mode). Either of these effects, as well as any other custom operation that mixes information across the batch, could lead to the outcome you observed.
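As a quick illustration of the batch-norm point (a minimal sketch, not part of the original answer; the layer and tensors below are made up for demonstration), a BatchNorm2d layer in train mode normalizes with statistics computed over the whole batch, so the same image gives a different output depending on which other images it is batched with, while eval mode removes the crosstalk:

import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(1)
big = torch.randn((10, 1, 28, 28))
small = big[:2]  # the same 2 images, just in a smaller batch

# train mode (the default): batch statistics differ between the two batches -> crosstalk
bn.train()
print(torch.allclose(bn(small), bn(big)[:2]))  # False (in general)

# eval mode: fixed running statistics are used for both batches -> no crosstalk
bn.eval()
print(torch.allclose(bn(small), bn(big)[:2]))  # True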

Aside from that, because of numeric precision issues, your code may not give exactly the same results even if it doesn't feature any cross-batch operations. This error is very minor, but it is sufficient to manifest itself when checking with CNN(input_1) == CNN(input_2)[:2]. It is better to use torch.allclose instead, with a suitable epsilon.
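For example (a minimal sketch; the tensors below just simulate two nearly identical outputs), torch.allclose with an explicit tolerance passes where an exact comparison would not:

import torch

out_1 = torch.randn((2, 16, 26, 26))
out_2 = out_1 + 1e-7 * torch.randn_like(out_1)  # simulate tiny floating-point noise

print(torch.equal(out_1, out_2))                            # False: exact equality is too strict
print(torch.allclose(out_1, out_2, rtol=1e-5, atol=1e-5))   # True: equal up to tolerance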

Just to add to Jatentaki's nice answer, below is a quick demonstration of the fact that a pure conv2d doesn't introduce "crosstalk" between items in a batch:

import torch
import torch.nn.functional as F


input_1 = torch.randn((10, 1, 28, 28))  # batch of 10 images
input_2 = input_1[:2]                   # the first 2 images of the same batch

weight = torch.randn((16, 1, 3, 3))     # the same convolution kernels for both

conv_1 = F.conv2d(input_1, weight)
conv_2 = F.conv2d(input_2, weight)

>>> torch.equal(conv_1[:2], conv_2)
True

So the reason for the discrepancy you get is probably one of those mentioned by Jatentaki (if you could show your CNN model, it would help to spot the exact reason).
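As a follow-up sketch (using a hypothetical small model, not the asker's actual CNN), if the discrepancy comes from batch norm or dropout, switching the model to eval mode should make the two outputs match up to floating-point tolerance:

import torch
import torch.nn as nn

torch.manual_seed(0)

# hypothetical model containing the two usual culprits: batch norm and dropout
model = nn.Sequential(
    nn.Conv2d(1, 16, 3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Dropout(0.5),
)

input_2 = torch.randn((10, 1, 28, 28))
input_1 = input_2[:2]

model.eval()  # use running batch-norm statistics and disable dropout
with torch.no_grad():
    out_1 = model(input_1)
    out_2 = model(input_2)

print(torch.allclose(out_1, out_2[:2]))  # True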
