![](/img/trans.png)
[英]Runtime Error: mat1 and mat2 shapes cannot be multiplied in pytorch
[英]Pytorch error mat1 and mat2 shapes cannot be multiplied
我收到此錯誤。 而我的輸入圖像的大小是 [3072,2,2],所以我通過以下代碼將圖像展平,但是,我收到了這個錯誤:
mat1 and mat2 shapes cannot be multiplied (6144x2 and 12288x512)
我的代碼:
class NeuralNet(nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(12288 ,512)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(512, 3)
def forward(self, x):
out = torch.flatten(x,0)
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
model = NeuralNet().to(device)
# Train the model
total_step = len(my_dataloader)
for epoch in range(5):
for i, (images, labels) in enumerate(my_dataloader):
# Move tensors to the configured device
images = images.to(device)
print(type(images))
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
首先,您的線性層的特征不正確。 功能內應該是您輸入的最后一個暗淡。 在這種情況下,它應該是nn.Linear(2,512)
class NeuralNet(nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(2 ,512)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(512, 3)
def forward(self, x):
out = torch.flatten(x,0)
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
基於PyTorch 文檔, torch.flatten(x,0)
返回形狀為[3072*2,2]
如果您希望在線性特征中具有[12288]
的形狀,則應使用torch.flatten(input, start_dim=0, end_dim=- 1)
class NeuralNet(nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(12288 ,512)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(512, 3)
def forward(self, x):
out = torch.flatten(x)
out = self.fc1(x,start_dim=0, end_dim=- 1)
out = self.relu(out)
out = self.fc2(out)
return out
你知道發生了哪個錯誤嗎?
mat1 and mat2 shapes cannot be multiplied (6144x2 and 12288x512)
您不能將(mxn)
矩陣與(pxn)
相乘。
your error : (6144x2) * (12288x512)
它必須是(mxn)
和(nxp)
。 這個“內部”維度需要相同(左矩陣的列數 = 右矩陣的行數)。
然后:---> out = torch.flatten(x,0) 將圖像 [3072,2,2] 更改為[3072*2,2] = [6144,2] (不是這個 [16288]),
和矩陣 [6144,2] 和 [2,512] 形狀可以相乘
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.