簡體   English   中英

在 cuda 上運行時出現 PyTorch 運行時錯誤

[英]PyTorch Runtime error while running on cuda

我正在嘗試對 GPU 上的數據進行推理。 該代碼在 CPU 上運行良好,但在 GPU 上失敗並出現以下錯誤:

RuntimeError: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, 
but expected them to be on GPU (while checking arguments for addmm) 

這是模型定義:

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # encoder
        self.enc1 = nn.Linear(in_features=672, out_features=256)
        self.enc2 = nn.Linear(in_features=256, out_features=128)
        self.enc3 = nn.Linear(in_features=128, out_features=64)
        self.enc4 = nn.Linear(in_features=64, out_features=32)
        self.enc5 = nn.Linear(in_features=32, out_features=16)

        self.enc1 = nn.ModuleList([self.enc1])
        self.enc2 = nn.ModuleList([self.enc2])
        self.enc3 = nn.ModuleList([self.enc3])
        self.enc4 = nn.ModuleList([self.enc4])
        self.enc5 = nn.ModuleList([self.enc5])

    def forward(self, x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = F.relu(self.enc5(x))
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        x = F.relu(self.dec5(x))
        return x

我使用它的地方:

def get_device():
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'
    return device    

def test_image_reconstruction(net, testloader):
    final_output = []
    for batch in testloader:
        img, label = batch
        img = img.to(device)
        net = net.to(device)
        outputs = net(img)
        loss = criterion(outputs, img)
    
    return final_output

df = pd.read_parquet(data_path)

net = torch.load(model_path)
criterion = nn.MSELoss()
device = get_device()

final_output = test_image_reconstruction(net, valloader)

看起來錯誤來自輸出 = net(img) 部分。 你知道我該如何解決嗎?

您的網絡參數位於 CPU,將其移至 GPU:

net = net.to(device)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM