
Runtime error: mat1 dim 1 must match mat2 dim 0

I am running a classification program using GaborNet. Part of my code is:

class module(nn.Module):
    def __init__(self):
        super(module, self).__init__()
        self.g0 = modConv2d(in_channels=3, out_channels=32, kernel_size=(11, 11), stride=1)
        self.c1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(2, 2), stride=1)
        self.c2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(2, 2), stride=1)
        ...

    def forward(self, x):
        ...
        #x = x.view(x.size(0), -1)
        #x = x.view(1, *x.shape)
        #x=x.view(-1,512*12*12)
        x = F.relu(self.fc1(x))
        print(x.shape)
        x = F.relu(self.fc2(x))
        print(x.shape)
        x = self.fc3(x)
        return x

I am getting this error at this position: x = F.relu(self.fc1(x))

and the error is: RuntimeError: mat1 dim 1 must match mat2 dim 0

However, the shapes of the tensor through the successive layers up to fc1 are:

torch.Size([64, 3, 150, 150])
torch.Size([64, 32, 140, 140])
torch.Size([64, 32, 70, 70])

You were on the right track: you do need to reshape your data just after the convolution layers and before proceeding with the fully connected layers.

The best approach for flattening your tensor is to use nn.Flatten, otherwise you might end up disrupting the batch size. The last spatial layer outputs a shape of (64, 128, 3, 3), and once flattened this tensor has a shape of (64, 1152), where 1152 = 128*3*3. Therefore your first fully connected layer should have 1152 input features.
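To see why nn.Flatten is preferable, here is a minimal sketch. The (64, 128, 3, 3) shape is the one quoted above, and the view call is the commented-out attempt from the question:

import torch
import torch.nn as nn

x = torch.randn(64, 128, 3, 3)       # output of the last spatial layer

flat = nn.Flatten()(x)               # flattens everything except the batch dimension
print(flat.shape)                    # torch.Size([64, 1152]), since 1152 = 128*3*3

wrong = x.view(-1, 512 * 12 * 12)    # hard-coding the wrong feature count...
print(wrong.shape)                   # torch.Size([1, 73728]) -- the batch size is disrupted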

Something like this should work:

class GaborNN(nn.Module):
    def __init__(self):
        super().__init__()
        ...
        self.fc1 = nn.Linear(in_features=1152, out_features=128)  # 1152 = 128*3*3
        self.fc2 = nn.Linear(in_features=128, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=7)
        self.flatten = nn.Flatten()  # keeps the batch dimension, flattens the rest

    def forward(self, x):
        ...
        x = self.flatten(x)          # (64, 128, 3, 3) -> (64, 1152)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
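If you later change the convolutional part and are unsure what in_features should become, one option is to push a dummy tensor through the spatial layers once and read off the flattened size. This is only a sketch: conv_layers below is a hypothetical stand-in mirroring the first two layers implied by the shapes in the question, not the full network.

import torch
import torch.nn as nn

# Hypothetical stand-in for the spatial part of the network (g0, pooling, c1, c2, ...).
conv_layers = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=(11, 11)),  # (1, 3, 150, 150) -> (1, 32, 140, 140)
    nn.MaxPool2d(2),                          # (1, 32, 140, 140) -> (1, 32, 70, 70)
    # ... remaining layers would go here ...
)

with torch.no_grad():
    dummy = torch.randn(1, 3, 150, 150)       # input size taken from the question
    n_features = conv_layers(dummy).flatten(1).shape[1]

print(n_features)  # use this value as in_features for the first nn.Linear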
