簡體   English   中英

PyTorch 自動編碼器:mat1 和 mat2 形狀不能相乘(1x512 和 12x64)

[英]PyTorch AutoEncoder : mat1 and mat2 shapes cannot be multiplied (1x512 and 12x64)

我正在嘗試通過自動編碼器傳遞 CNN 的 output 功能。 我使用 hooklayer 來提取 CNN 的特征並將它們轉換為張量。

extracted_features = torch.tensor(rn_output)

元組到張量轉換后的數據大小為torch.Size([1014,512])

AutoEncoder 的解碼器部分拋出“不能相乘錯誤”,但我認為錯誤是由於輸入的設置和形狀造成的。

自動編碼器

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),  # N, 512 -> N,128 
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=64, out_features=12),
        )
        self.decoder = nn.Sequential(
            nn.Linear(in_features=12, out_features=64),  # N, 3 -> N,12 
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=64, out_features=128),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=256, out_features=512),
            nn.Tanh()
        )
    
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(x)
        return decoded

調用自動編碼器


model = AutoEncoder()
criterion = nn.MSELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

num_epochs = 10
outputs = []
for epoch in range(num_epochs): 
    for (img) in extracted_features:
        recon = model(img)
        loss = criterion(recon, img)
        
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        
    print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
    outputs.append((epoch, img, recon))

我嘗試使用數據加載器並以較小的批量傳入數據。 我也嘗試過在 forward 方法中重塑圖像,但我仍然繼續得到同樣的錯誤

我很確定您的forward function 錯誤地執行了編碼器-解碼器步驟。 我認為你應該改變它:

encoded = self.encoder(x)
decoded = self.decoder(x)

對此:

encoded = self.encoder(x)
decoded = self.decoder(encoded)

解碼器通常對編碼輸入進行操作,而不是直接對輸入本身進行操作,除非您使用的是我不熟悉的編碼器-解碼器的非標准定義。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM