
What makes a pre-trained model in PyTorch misclassify an image?

I successfully trained a Data-efficient Image Transformer (DeiT) on the CIFAR-10 dataset with an accuracy of about 95% and saved it for later use. I then created a separate class to load the model and run inference on a single image. However, I get a different prediction every time I run it.

import torch
from models.deit import deit_small_patch16_224

import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision.transforms import transforms as transforms
class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

model = deit_small_patch16_224(pretrained=True, use_top_n_heads=8, use_patch_outputs=False)

checkpoint = torch.load("./checkpoint/deit224.t7")
model.load_state_dict(checkpoint, strict=False)
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=10)
model.eval()


img = Image.open("cats.jpeg")
img_tensor = torch.tensor(np.array(img))/255.0
img_tensor = img_tensor.unsqueeze(0).permute(0, 3, 1, 2)
# print(img_tensor.shape)
with torch.no_grad():
    output = model(img_tensor)
    predicted_class = np.argmax(output)
    print(predicted_class)
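
A prediction that changes between runs is what one would expect if some layer is still randomly initialised at inference time. The following is a minimal sketch of how that can be checked, not the asker's actual setup: `TinyNet` is a hypothetical stand-in for the real network, and the `"model"` wrapper and `"module."` key prefix in the fake checkpoint are assumptions made for the demonstration. It relies on the fact that `load_state_dict` returns the lists of missing and unexpected keys.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model: one body layer, one head."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.body = nn.Linear(4, 8)
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.head(torch.relu(self.body(x)))


# Fake checkpoint: weights nested under "model", every key prefixed with
# "module." (both are assumptions for this demo).
trained = TinyNet().eval()
checkpoint = {"model": {"module." + k: v for k, v in trained.state_dict().items()}}

# 1) Passing the outer dict with strict=False matches nothing, silently.
fresh = TinyNet().eval()
result = fresh.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# 2) Unwrap "model", strip the prefix, and load strictly (the default), so
#    any remaining mismatch raises instead of being skipped.
prefix = "module."
clean = {k[len(prefix):]: v for k, v in checkpoint["model"].items()}
fresh.load_state_dict(clean)

x = torch.randn(1, 4)
with torch.no_grad():
    same = torch.equal(fresh(x), trained(x))
print("outputs match after strict load:", same)

# 3) Replacing the head AFTER loading throws the trained head away; the new
#    layer starts from random weights.
fresh.head = nn.Linear(8, 10)
with torch.no_grad():
    changed = not torch.equal(fresh(x), trained(x))
print("outputs changed after replacing head:", changed)
```

If the missing-keys list covers the whole model, as in step 1, nothing from the checkpoint was actually loaded.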

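Yes, I figured out the error. The checkpoint keeps the weights under its "model" key, and each parameter name carries an extra leading component, so the original load_state_dict(checkpoint, strict=False) call matched no keys and silently loaded nothing. The 10-class head was also created after loading, which left it randomly initialised on every run. Updated code below: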

import torch
from PIL import Image
from torchvision import transforms

from models.deit import deit_small_patch16_224

class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

model = deit_small_patch16_224(pretrained=True, use_top_n_heads=8, use_patch_outputs=False)

# The trained weights live under the "model" key of the checkpoint
checkpoint = torch.load("./checkpoint/deit224.t7", map_location="cpu")
state_dict = checkpoint["model"]

# Strip the first dotted component of every key so the names match this model
new_state_dict = {}
for key in state_dict:
    new_key = '.'.join(key.split('.')[1:])
    new_state_dict[new_key] = state_dict[key]

# Create the 10-class head BEFORE loading so its trained weights are restored
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=10)
model.load_state_dict(new_state_dict)  # strict by default: mismatches raise
model.eval()

# Force 3 channels; ToTensor scales to [0, 1] and returns a CHW float tensor
img = Image.open("cats.jpeg").convert("RGB")
img_tensor = transforms.ToTensor()(img).unsqueeze(0)

with torch.no_grad():
    output = model(img_tensor)
    predicted_class = output.argmax(dim=1).item()
    print(predicted_class, class_names[predicted_class])
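
The key-renaming loop above drops the first dotted component of every key unconditionally, which works only when every key really has that extra component. A slightly more defensive variant might remove a known prefix only where it is present. This is a sketch under assumptions: `strip_prefix` is a hypothetical helper, and `"module."` is a guess at the prefix, since the post does not say what the leading component is. Plain strings stand in for tensors because only the key handling matters here.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that have it."""
    stripped = {}
    for key, value in state_dict.items():
        # Leave keys without the prefix untouched instead of mangling them
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Strings stand in for tensors; the keys are illustrative only.
example = {
    "module.head.weight": "W",
    "module.head.bias": "b",
    "pos_embed": "P",  # a key with no dot at all
}
print(strip_prefix(example))

# For comparison: the unconditional split turns a dot-free key into "".
print(repr('.'.join("pos_embed".split('.')[1:])))
```

Loading the result with the default strict behaviour of `load_state_dict` would then surface any keys that still do not match.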

