簡體   English   中英

Pytorch 錯誤 mat1 和 mat2 形狀不能相乘

[英]Pytorch error mat1 and mat2 shapes cannot be multiplied

我收到此錯誤。 而我的輸入圖像的大小是 [3072,2,2],所以我通過以下代碼將圖像展平,但是,我收到了這個錯誤:

mat1 and mat2 shapes cannot be multiplied (6144x2 and 12288x512)

我的代碼:

class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(12288 ,512) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)  
    
    def forward(self, x):
        out = torch.flatten(x,0)
        
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = NeuralNet().to(device)

# Train the model
total_step = len(my_dataloader)
for epoch in range(5):
    for i, (images, labels) in enumerate(my_dataloader):  
        # Move tensors to the configured device
        images = images.to(device)
        print(type(images))
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

首先,您的線性層的特征不正確。 功能內應該是您輸入的最后一個暗淡。 在這種情況下,它應該是nn.Linear(2,512)

class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(2 ,512) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)  
    
    def forward(self, x):
        out = torch.flatten(x,0)
        
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

基於PyTorch 文檔torch.flatten(x,0)返回形狀為[3072*2,2]如果您希望在線性特征中具有[12288]的形狀,則應使用torch.flatten(input, start_dim=0, end_dim=- 1)

class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(12288 ,512) 
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)  
    
    def forward(self, x):
        out = torch.flatten(x)
        
        out = self.fc1(x,start_dim=0, end_dim=- 1)
        out = self.relu(out)
        out = self.fc2(out)
        return out

你知道發生了哪個錯誤嗎?

mat1 and mat2 shapes cannot be multiplied (6144x2 and 12288x512)

您不能將(mxn)矩陣與(pxn)相乘。

your error : (6144x2) * (12288x512)

它必須是(mxn)(nxp) 這個“內部”維度需要相同(左矩陣的列數 = 右矩陣的行數)。

然后:---> out = torch.flatten(x,0) 將圖像 [3072,2,2] 更改為[3072*2,2] = [6144,2] (不是這個 [16288]),

和矩陣 [6144,2] 和 [2,512] 形狀可以相乘

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM