[英]How do I use a .pickle file to predict an image?
我在 PyTorch 中训练了一个 CNN model 来检测 6 个不同类别的皮肤病。 我的 model 的准确率为 92%,我将其保存在一个 .pickle 文件中。 我希望使用此 model 进行预测,但我不知道该怎么做。 如果有人可以在必要的步骤中帮助我,我将不胜感激。 我曾尝试使用 Streamlit,但显然 Streamlit 不再起作用,因此我选择了离线解决方案,我可以只上传图像,model 会给我这样的预测。
这是我的 model 的代码。 我使用了预训练的 ResNet18 model 并在 Kaggle 的 Skin Cancer MNIST: HAM10000 数据集上对其进行了训练。
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when doing feature extraction.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        feature_extracting: When truthy, ``requires_grad`` is set to ``False``
            on every parameter; when falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for param in model.parameters():
        param.requires_grad = False
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Create a torchvision backbone whose output layer has ``num_classes`` units.

    Args:
        model_name: One of ``"resnet"``, ``"vgg"``, ``"densenet"`` or
            ``"inception"``.
        num_classes: Number of outputs of the replacement classification layer.
        feature_extract: Forwarded to ``set_parameter_requires_grad``; when
            truthy the backbone parameters are frozen before the new layer is
            attached.
        use_pretrained: Forwarded as ``pretrained=`` to the torchvision
            constructor.

    Returns:
        Tuple ``(model_ft, input_size)`` where ``input_size`` is the square
        input edge length set for the chosen architecture (224 or 299).

    Raises:
        SystemExit: If ``model_name`` is not supported (a message is printed
            first).
    """
    if model_name == "resnet":
        # ResNet-18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        # VGG11 with batch normalisation
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        # Inception v3
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        print("Invalid model name, exiting...")
        # ``exit()`` is only injected by the ``site`` module and may be absent;
        # raising SystemExit directly terminates the same way everywhere.
        raise SystemExit
    return model_ft, input_size
# Backbone to fine-tune; supported names: resnet, vgg, densenet, inception
model_name = 'resnet'
num_classes = 7
feature_extract = False
# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
# Define the device: use the first GPU when CUDA is available, otherwise fall
# back to the CPU so the script also runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Put the model on the device:
model = model_ft.to(device)
# Example normalisation statistics, left disabled as in the original script.
# ``norm_mean`` and ``norm_std`` used below must already be defined.
# norm_mean = (0.49139968, 0.48215827, 0.44653124)
# norm_std = (0.24703233, 0.24348505, 0.26158768)

# Training pipeline: resize, random flips / rotation / colour jitter, then
# tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation pipeline: resize, tensor conversion and normalisation only.
val_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# Define a pytorch dataloader for this dataset
class HAM10000(Dataset):
    """Dataset backed by a metadata table of image paths and labels.

    ``df`` must provide a ``'path'`` column (passed to ``Image.open``) and a
    ``'cell_type_idx'`` column (converted to an integer label tensor).
    """

    def __init__(self, df, transform=None):
        """Store the metadata table and the optional per-image transform."""
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``."""
        # Positional lookup: a data loader hands out positions 0..len-1, which
        # only coincide with index labels when the frame index was reset.
        X = Image.open(self.df['path'].iloc[index])
        y = torch.tensor(int(self.df['cell_type_idx'].iloc[index]))
        if self.transform:
            X = self.transform(X)
        return X, y
# Training set: rows of df_train with the augmenting train_transform
training_set = HAM10000(df_train, transform=train_transform)
train_loader = DataLoader(training_set, batch_size=64, shuffle=True, num_workers=4)
# Validation set: use val_transform (no random augmentation) so validation
# metrics are deterministic and comparable between epochs.
validation_set = HAM10000(df_val, transform=val_transform)
val_loader = DataLoader(validation_set, batch_size=64, shuffle=False, num_workers=4)
# we use Adam optimizer, use cross entropy loss as our loss function
optimizer = optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss().to(device)
这是训练过程和保存文件。
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes are ``val`` (last value passed to ``update``), ``sum`` (weighted
    sum of all values), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(..., n=0) on a fresh
        # meter) instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
# Running training loss / accuracy, sampled every 100 iterations by train().
total_loss_train, total_acc_train = [], []


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch and return the averaged loss and accuracy.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to optimise; switched to training mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress message.

    Returns:
        Tuple ``(train_loss.avg, train_acc.avg)`` for this epoch.
    """
    model.train()
    train_loss = AverageMeter()
    train_acc = AverageMeter()
    for i, (images, labels) in enumerate(train_loader):
        N = images.size(0)
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        prediction = outputs.max(1, keepdim=True)[1]
        correct = prediction.eq(labels.view_as(prediction)).sum().item()
        # Weight each batch by its size so a smaller final batch does not
        # skew the epoch averages.
        train_acc.update(correct / N, N)
        train_loss.update(loss.item(), N)
        if (i + 1) % 100 == 0:
            print('[epoch %d], [iter %d / %d], [train loss %.5f], [train acc %.5f]' % (
                epoch, i + 1, len(train_loader), train_loss.avg, train_acc.avg))
            total_loss_train.append(train_loss.avg)
            total_acc_train.append(train_acc.avg)
    return train_loss.avg, train_acc.avg
def validate(val_loader, model, criterion, optimizer, epoch):
    """Evaluate ``model`` on ``val_loader`` and return averaged loss and accuracy.

    Args:
        val_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to evaluate; switched to evaluation mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Unused; kept so the signature mirrors ``train``.
        epoch: Epoch number, used only in the summary message.

    Returns:
        Tuple ``(val_loss.avg, val_acc.avg)``.
    """
    model.eval()
    val_loss = AverageMeter()
    val_acc = AverageMeter()
    with torch.no_grad():
        for images, labels in val_loader:
            N = images.size(0)
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            prediction = outputs.max(1, keepdim=True)[1]
            correct = prediction.eq(labels.view_as(prediction)).sum().item()
            # Weight each batch by its size so a smaller final batch does not
            # skew the averages.
            val_acc.update(correct / N, N)
            val_loss.update(criterion(outputs, labels).item(), N)

    print('------------------------------------------------------------')
    print('[epoch %d], [val loss %.5f], [val acc %.5f]' % (epoch, val_loss.avg, val_acc.avg))
    print('------------------------------------------------------------')
    return val_loss.avg, val_acc.avg
# Reuse a previously trained model when its pickle exists; otherwise train one.
if os.path.exists("Tested_model2.pickle"):
    print("Loading Trained Model")
    # NOTE: unpickling executes code embedded in the file; only load pickles
    # that come from a trusted source.
    with open("Tested_model2.pickle", "rb") as file:
        model = pickle.load(file)
    print(model)
else:
    print("Training New Model.")
    print("Training begins.")
    print("********************************************************")
    epoch_num = 25
    load_model = True
    best_val_acc = 0
    total_loss_val, total_acc_val = [], []
    for epoch in range(1, epoch_num+1):
        loss_train, acc_train = train(train_loader, model, criterion, optimizer, epoch)
        loss_val, acc_val = validate(val_loader, model, criterion, optimizer, epoch)
        total_loss_val.append(loss_val)
        total_acc_val.append(acc_val)
        if acc_val > best_val_acc:
            best_val_acc = acc_val
            print('*****************************************************')
            print('best record: [epoch %d], [val loss %.5f], [val acc %.5f]' % (epoch, loss_val, acc_val))
            print('*****************************************************')
    # NOTE(review): the original indentation was lost; the model is assumed to
    # be saved once, after the last epoch - confirm against the notebook.
    with open("Tested_model2.pickle", "wb") as file:
        pickle.dump(model, file)
简单地说,我想知道如何使用 pickle 文件进行预测。
编辑:我在下面添加了评估部分,请帮助我了解如何进一步处理此代码。
# Collect ground-truth labels and predicted class indices over the validation set.
model.eval()
y_label = []
y_predict = []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        # Index of the largest score per sample, shape (batch,). Keeping it
        # 1-D (rather than squeezing a (1, batch) array) also works for a
        # batch that holds a single image.
        prediction = outputs.max(1)[1]
        y_label.extend(labels.cpu().numpy())
        y_predict.extend(prediction.cpu().numpy())
另外,这是我之前用来加载和预测的代码,但是我不知道代码是否正确或方法是否正确。
%%writefile app.py
import streamlit as st
import torch
st.set_option('deprecation.showfileUploaderEncoding', False)


@st.cache(allow_output_mutation=True)
def load_model():
    """Unpickle and return the trained model from ``Trained_Model_part2.pickle``.

    NOTE: unpickling executes code embedded in the file; only load pickles
    that come from a trusted source.
    """
    # Local import: no module-level ``import pickle`` is visible in this script.
    import pickle

    with open("Trained_Model_part2.pickle", "rb") as model_file:
        model = pickle.load(model_file)
    return model
# Load the classifier once (load_model is wrapped by the Streamlit cache).
model = load_model()

# Page title followed by the image upload widget.
st.write(
    """
#Classification of skin disease
"""
)
file = st.file_uploader(
    "Please upload the image of the affected area.",
    type=["jpg", "png"],
)
import cv2
from PIL import Image, ImageOps
import numpy as np
def import_and_predict(image_data, model, mean=None, std=None):
    """Run a PyTorch ``model`` on one PIL image and return its raw scores.

    Args:
        image_data: PIL image to classify.
        model: Trained ``torch.nn.Module``; it is switched to evaluation mode.
        mean: Optional per-channel mean used by ``transforms.Normalize`` during
            training. Applied only when ``std`` is given as well.
        std: Optional per-channel standard deviation used during training.

    Returns:
        NumPy array holding the model output for a batch of one image, so
        ``np.argmax`` on it yields the predicted class index.
    """
    size = (224, 224)
    # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS, an
    # alias that newer Pillow releases no longer provide.
    image = ImageOps.fit(image_data.convert("RGB"), size, Image.LANCZOS)

    # HWC uint8 pixels -> float32 in [0, 1], matching transforms.ToTensor().
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    if mean is not None and std is not None:
        pixels = (pixels - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)

    # Reorder to channels-first and add the batch dimension: (1, 3, 224, 224).
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).contiguous()

    # Move the input to whichever device the model weights live on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    model.eval()
    with torch.no_grad():
        prediction = model(batch)
    return prediction.cpu().numpy()
# Once an image is uploaded, display it together with the predicted class;
# until then, prompt the user for a file.
if file is not None:
    image = Image.open(file)
    st.image(image, use_column_width=True)
    predictions = import_and_predict(image, model)
    class_names = [
        "Melanocytic nevi",
        "Melanoma",
        "Benign keratosis-like lesions",
        "Basal cell carcinoma",
        "Actinic keratoses",
        "Vascular lesions",
        "Dermatofibroma",
    ]
    string = f"It is: {class_names[np.argmax(predictions)]}"
    st.success(string)
else:
    st.text("Please upload an image file.")
以上代码使用 Streamlit,其中的加载器仍是之前的 pickle 文件加载器,之后会替换为 .pth 文件加载器。 我想知道需要做哪些修改,才能让代码提示输入图像(或在指定文件夹中查找图像)并给出预测。 谢谢。
我将向您展示如何正确保存和加载 PyTorch model 参数(您应该使用 .pt 扩展名):
要保存 model,请执行以下操作(每个 epoch 或训练后一次):
# Persist only the learned parameters (the state dict) of the model.
state = model.state_dict()
torch.save(state, "your/path/model_file.pt")
所有 model 参数现在都已保存到“your/path/model_file.pt”中。
现在,要加载 model,您将需要 model class YourModel(nn.Module): ...
# Read the saved parameters from disk, then copy them into a fresh model.
saved_state = torch.load("your/path/model_file.pt")
model = YourModel()
model.load_state_dict(saved_state)
model 现在已使用经过优化的参数进行初始化并可以使用。 例如像这样:
model = YourModel()
model.load_state_dict(torch.load("your/path/model_file.pt"))
# set to evaluation mode
model.eval()
# load an image
sample = get_sample()
# add a leading batch dimension of size 1, because in real-life usage you
# probably want to predict just one image at a time
# (torch.reshape expects the tensor as its first argument, not a size)
sample = sample.unsqueeze(0)
# gradients are not needed for inference
with torch.no_grad():
    prediction = model(sample)
编辑以回答评论中的问题:要加载经过训练的 PyTorch model,您需要保存了 model 参数的文件以及 model 结构本身。 model 结构只是一个 PyTorch 模块 class 的 Python 代码。 您没有自己构建 model,因此没有直接的 model 代码,但在您的情况下它应该是 model_ft。 它只是一个持有所有层的 Python class。 所以 model class 就像骨架,而参数就像附着在骨架上的肉。
当您完全自己创建 model class 并将训练的权重加载到其中时,它看起来像这样:
import torch
import torch.nn as nn
import torch.nn.functional as F
# the model (skeleton class)
class YourModel(nn.Module):
    """Example two-layer network: 128 inputs -> 64 hidden units -> 2 outputs."""

    def __init__(self):
        super(YourModel, self).__init__()
        self.dense1 = nn.Linear(128, 64)
        self.dense2 = nn.Linear(64, 2)

    def forward(self, x):
        """Apply ReLU after the first layer and a sigmoid after the second."""
        x = F.relu(self.dense1(x))
        x = torch.sigmoid(self.dense2(x))
        return x
# . . .
# train model and save it to model.pt
# . . .
# instantiate the untrained ("empty") skeleton
model = YourModel()
# read the trained parameters/weights from disk and copy them into the model
trained_state = torch.load("/path/model.pt")
model.load_state_dict(trained_state)
正如我所说,在您的情况下,model class 应该是 model_ft。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.