
Pass an arbitrary image size to a CNN in PyTorch

I'm trying to train a LeNet model in PyTorch. The idea is to feed it images of any size, so I started with nn.AdaptiveAvgPool2d, but I get the error

mat1 dim 1 must match mat2 dim 0

Here is my code:

import torch.nn as nn

class LeNet5(nn.Module):
  def __init__(self, num_classes=10):
    super(LeNet5, self).__init__()

    self.conv_1 = nn.Conv2d(
        in_channels=1, out_channels=32, kernel_size=5, bias=False
    )
    self.relu_1 = nn.ReLU(inplace=True)
    self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv_2 = nn.Conv2d(
        in_channels=32, out_channels=256, kernel_size=5, bias=False
    )
    self.relu_2 = nn.ReLU(inplace=True)
    self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
    self.flatten = nn.Flatten()
    self.fc_1 = nn.Linear(in_features=4096, out_features=120, bias=False)
    self.fc_2 = nn.Linear(in_features=120, out_features=84)
    self.fc_3 = nn.Linear(in_features=84, out_features=num_classes)

  def forward(self, input):
    conv_1_output = self.conv_1(input)
    relu_1_output = self.relu_1(conv_1_output)
    maxpool_1_output = self.maxpool_1(relu_1_output)
    conv_2_output = self.conv_2(maxpool_1_output)
    relu_2_output = self.relu_2(conv_2_output)
    maxpool_2_output = self.maxpool_2(relu_2_output)
    flatten_output = self.flatten((self.avgpool(maxpool_2_output).view(maxpool_2_output.shape[0], -1)))
    fc_1_output = self.fc_1(flatten_output)
    fc_2_output = self.fc_2(fc_1_output)
    fc_3_output = self.fc_3(fc_2_output)

    return fc_3_output

If you read the documentation for AdaptiveAvgPool2d, this is what it says: "we specify the output size, and the stride and kernel size are automatically selected to adapt to the needs."
More info is available here.
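A quick sketch of what that means in practice (the tensor sizes below are just illustrative): whatever the spatial size of the input, output_size=1 collapses it to 1x1 while leaving the channel count untouched.

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(output_size=1)

# The spatial dimensions collapse to 1x1; the channel dimension is untouched.
for hw in (28, 32, 100):
    x = torch.randn(8, 256, hw, hw)   # batch of 8 feature maps with 256 channels
    print(pool(x).shape)              # torch.Size([8, 256, 1, 1]) every time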
Hence, AdaptiveAvgPool2d reduces your spatial dimensions, not the depth of the feature maps. The spatial size becomes 1x1 while the depth is still 256, so the first fully connected layer should be self.fc_1 = nn.Linear(in_features=256, out_features=120, bias=False) and not self.fc_1 = nn.Linear(in_features=4096, out_features=120, bias=False).
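And a minimal sketch of the fix, assuming the LeNet5 class from the question is defined as above (the input sizes below are chosen just for illustration):

import torch
import torch.nn as nn

model = LeNet5(num_classes=10)
# Replace the first fully connected layer: the adaptive pool yields 256 features per image.
model.fc_1 = nn.Linear(in_features=256, out_features=120, bias=False)

# Any input large enough to survive both conv/pool stages now works:
for size in (28, 32, 64, 128):
    x = torch.randn(4, 1, size, size)   # batch of 4 single-channel images
    print(size, model(x).shape)         # torch.Size([4, 10]) for every size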
