繁体   English   中英

使用 Pytorch 的 CNN 线性回归:输入和目标形状不匹配:输入 [400 x 1],目标 [200 x 1]

[英]Linear Regression with CNN using Pytorch: input and target shapes do not match: input [400 x 1], target [200 x 1]

让我先解释一下目标。 假设我有 1000 张图像,每张图像都有相关的质量分数 [范围为 0-10]。 现在,我正在尝试使用带有回归的 CNN(在 PyTorch 中)执行图像质量评估。 我已将图像分成相同大小的补丁。 现在,我创建了一个 CNN 网络来执行线性回归。

以下是代码:

class MultiLabelNN(nn.Module):
    def __init__(self):
        super(MultiLabelNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(3200,1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(-1, 3200)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x  

运行此网络代码时,出现以下错误

输入和目标形状不匹配:输入 [400 x 1],目标 [200 x 1]

目标形状是 [200x1] 是因为我的批量大小为 200。我找到了解决方案,如果我更改“self.fc1 = nn.Linear(3200,1024)”和“x = x.view(-1 , 3200)" 在这里从 3200 到 6400 我的代码运行没有任何错误。

同样,如果我输入 12800 而不是 6400,它会抛出错误输入和目标形状不匹配:输入 [100 x 1],目标 [200 x 1]

现在我怀疑是我无法理解这背后的原因。 如果我给我的网络提供 200 张图像作为输入,那么为什么当我从卷积层移动到全连接层时,在更改参数时输入形状会受到影响。 我希望我已经清楚地提到了我的怀疑。 即使我有人有任何疑问,请问我。 这将是一个很大的帮助。 提前致谢。

class MultiLabelNN(nn.Module):
    def __init__(self):
        super(MultiLabelNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(6400,1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

   def forward(self, x):
       #shape of x is (b_s, 32,32,1)
       x = self.conv1(x) #shape of x is (b_s, 28,28,132)
       x = F.relu(x)
       x = self.pool(x) #shape of x now becomes (b_s X 14 x 14 x 32)
       x = self.conv2(x) # shape(b_s, 10x10x64)
       x = F.relu(x)#size is (b_s x 10 x 10 x 64)
       x = x.view(-1, 3200) # shape of x is now(b_s*2, 3200)
       #this is the problem 
       #you can fc1 to be of shape (6400,1024) and that will work 
       x = self.fc1(x)
       x = F.relu(x)
       x = self.fc2(x)
       x = F.relu(x)
       x = self.fc3(x)
       return x  

我认为这应该有效。 如果仍然存在一些错误,请告诉我。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM