I'm trying to train a LeNet model in PyTorch. The idea is to feed it images of any size, so I started using nn.AdaptiveAvgPool2d, but I get the following error:
mat1 dim 1 must match mat2 dim 0
Here is my code:
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.conv_1 = nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=5, bias=False
        )
        self.relu_1 = nn.ReLU(inplace=True)
        self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_2 = nn.Conv2d(
            in_channels=32, out_channels=256, kernel_size=5, bias=False
        )
        self.relu_2 = nn.ReLU(inplace=True)
        self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Pool each feature map down to 1x1, whatever the input size
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten()
        # First fully connected layer, declared with 4096 input features
        self.fc_1 = nn.Linear(in_features=4096, out_features=120, bias=False)
        self.fc_2 = nn.Linear(in_features=120, out_features=84)
        self.fc_3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, input):
        conv_1_output = self.conv_1(input)
        relu_1_output = self.relu_1(conv_1_output)
        maxpool_1_output = self.maxpool_1(relu_1_output)
        conv_2_output = self.conv_2(maxpool_1_output)
        relu_2_output = self.relu_2(conv_2_output)
        maxpool_2_output = self.maxpool_2(relu_2_output)
        flatten_output = self.flatten((self.avgpool(maxpool_2_output).view(maxpool_2_output.shape[0], -1)))
        fc_1_output = self.fc_1(flatten_output)
        fc_2_output = self.fc_2(fc_1_output)
        fc_3_output = self.fc_3(fc_2_output)
        return fc_3_output
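To see where the mismatch comes from, the feature-extraction layers can be run on their own and the flattened shape inspected. The sketch below rebuilds those layers as an nn.Sequential (a restructuring for illustration only, not the original class), feeds it random single-channel inputs of a few arbitrarily chosen sizes with an arbitrary batch size of 4, and then passes the result to a layer declared with 4096 input features, as fc_1 is above.

```python
import torch
import torch.nn as nn

# Same feature-extraction layers as in LeNet5, up to and including the flatten step
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5, bias=False),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(32, 256, kernel_size=5, bias=False),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.AdaptiveAvgPool2d(output_size=1),
    nn.Flatten(),
)

flat_sizes = {}
for size in (28, 32, 64, 100):
    x = torch.randn(4, 1, size, size)  # batch of 4 single-channel images
    out = features(x)
    flat_sizes[size] = out.shape[1]
    print(f"input {size}x{size} -> flattened shape {tuple(out.shape)}")

# A linear layer declared like fc_1 in the question
fc_1 = nn.Linear(in_features=4096, out_features=120, bias=False)
try:
    fc_1(out)
    mismatch_raised = False
except RuntimeError as err:
    # The exact error wording differs between PyTorch versions
    mismatch_raised = True
    print("RuntimeError:", err)
```

Every input size ends up with the same number of flattened features per sample, and that number is not 4096, which is why fc_1 rejects it.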
If you read the theory behind AdaptiveAvgPool2d, this is what it says: "we specify the output size, and the stride and kernel size are automatically selected to adapt to the needs".
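A quick way to check this behaviour: with output_size=1, the layer averages over the whole spatial extent of each channel, so its result should match a plain mean over the height and width dimensions. The 7x5 spatial size and batch of 2 below are arbitrary choices for the sketch.

```python
import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(output_size=1)

# Batch of 2 inputs with 256 channels and an arbitrary 7x5 spatial size
x = torch.randn(2, 256, 7, 5)
pooled = pool(x)

# Global average over height (dim 2) and width (dim 3), keeping those dims as size 1
manual = x.mean(dim=(2, 3), keepdim=True)

print(pooled.shape)  # torch.Size([2, 256, 1, 1])
print(torch.allclose(pooled, manual, atol=1e-6))
```

Only the last two dimensions shrink; the 256 channels pass through unchanged.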
So AdaptiveAvgPool2d reduces the spatial dimensions of your feature maps, not their depth (the number of channels). With output_size=1 the spatial dimensions become 1x1 while the depth stays at 256, so each sample is flattened to 256 features. Your first fully connected layer should therefore be
self.fc_1 = nn.Linear(in_features=256, out_features=120, bias=False)
and not
self.fc_1 = nn.Linear(in_features=4096, out_features=120, bias=False)
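Putting it together, here is a sketch of the same model with only fc_1 changed to 256 input features. The redundant .view(...) call before self.flatten is dropped, since nn.Flatten already produces the (N, 256) shape; the three linear layers are kept without activations in between, as in the question. The loop runs it on a few arbitrarily chosen input sizes.

```python
import torch
import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, bias=False)
        self.relu_1 = nn.ReLU(inplace=True)
        self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_2 = nn.Conv2d(in_channels=32, out_channels=256, kernel_size=5, bias=False)
        self.relu_2 = nn.ReLU(inplace=True)
        self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Reduces H x W to 1 x 1; the 256 channels are untouched
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten()  # (N, 256, 1, 1) -> (N, 256)
        # in_features now matches the 256 flattened features
        self.fc_1 = nn.Linear(in_features=256, out_features=120, bias=False)
        self.fc_2 = nn.Linear(in_features=120, out_features=84)
        self.fc_3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool_1(self.relu_1(self.conv_1(x)))
        x = self.maxpool_2(self.relu_2(self.conv_2(x)))
        x = self.flatten(self.avgpool(x))
        x = self.fc_1(x)
        x = self.fc_2(x)
        return self.fc_3(x)


model = LeNet5(num_classes=10)
output_shapes = []
for size in (28, 64, 128):
    logits = model(torch.randn(2, 1, size, size))  # batch of 2 single-channel images
    output_shapes.append(tuple(logits.shape))
    print(size, tuple(logits.shape))
```

Note that "any size" still has a lower bound: the two 5x5 convolutions and two 2x2 max pools before the adaptive pooling need an input of at least 16x16, otherwise the second max pool has nothing left to pool.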