簡體   English   中英

如何使用 PyTorch model 進行預測?

[英]How do I predict using a PyTorch model?

我創建了一個 pyTorch Model 來對圖像進行分類。 我通過 state_dict 和整個 model 保存了一次,如下所示:

torch.save(model.state_dict(), "model1_statedict")
torch.save(model, "model1_complete")

我如何使用這些模型? 我想用一些圖像來檢查它們,看看它們是否良好。

我正在加載 model:

model = torch.load(path_model)
model.eval()

這很好用,但我不知道如何使用它來預測新圖片。

def predict(self, test_images):
    self.eval()
    # model is self(VGG class's object)
    
    count = test_images.shape[0]
    result_np = []
        
    for idx in range(0, count):
        # print(idx)
        img = test_images[idx, :, :, :]
        img = np.expand_dims(img, axis=0)
        img = torch.Tensor(img).permute(0, 3, 1, 2).to(device)
        # print(img.shape)
        pred = self(img)
        pred_np = pred.cpu().detach().numpy()
        for elem in pred_np:
            result_np.append(elem)
    return result_np

網絡是 VGG-19 並參考我的源代碼。

像這樣的架構:

class VGG(object):
    def __init__(self):
    ...


    def train(self, train_images, valid_images):
        train_dataset = torch.utils.data.Dataset(train_images)
        valid_dataset = torch.utils.data.Dataset(valid_images)

        trainloader = torch.utils.data.DataLoader(train_dataset)
        validloader = torch.utils.data.DataLoader(valid_dataset)

        self.optimizer = Adam(...)
        self.criterion = CrossEntropyLoss(...)
    
        for epoch in range(0, epochs):
            ...
            self.evaluate(validloader, model=self, criterion=self.criterion)
    ...

    def evaluate(self, dataloader, model, criterion):
        model.eval()
        for i, sample in enumerate(dataloader):
    ...

    def predict(self, test_images):
    
    ...

if __name__ == "__main__":
    network = VGG()
    trainset, validset = get_dataset()    # abstract function for showing
    testset = get_test_dataset()
    
    network.train(trainset, validset)

    result = network.predict(testset)

pytorch model 是 function。 您為其提供適當定義的輸入,它會返回 output。 如果您只想在給定特定輸入圖像的情況下目視檢查 output,只需調用它:

model.eval()
output = model(example_image)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM