I have trained a ResNet50 model on the Intel image multiclass classification task. The task is to predict whether an image shows a building, a street, a glacier, etc. The model trained successfully and is able to make predictions. I have saved the model and am now trying to use the saved model on a new image. Here is the training code:
import os
import torch
import tarfile
import torchvision
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import random_split
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.utils import download_url
import PIL
import PIL.Image
import numpy as np
transform_train = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
transform_test = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
...
torch.save(model2.state_dict(),'/content/drive/MyDrive/saved_model/model_resnet.pth')
When I load the model in another file, I use a similar image transformation, but it gives me an error. Here is the code and the error:
model = torch.load('/content/drive/MyDrive/saved_model/model_resnet.pth')
image=Image.open(Path('/content/drive/MyDrive/images/seg_pred/seg_pred/10004.jpg'))
transform_train = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
input = transform_train(image)
#input = input.view(1, 3, 150,150)
output = model(input)
prediction = int(torch.max(output.data, 1)[1].numpy())
print(prediction)
The error I get is:
TypeError: 'collections.OrderedDict' object is not callable
My PyTorch version is 1.9.0+cu102.
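You need to create the model structure first, the same way model2 is created in your training code. Because you saved model2.state_dict(), torch.load on that file returns only the state dict (an OrderedDict of weights), not a model, which is why calling it raises the TypeError. It can look like: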
model = resnet()
Then load the saved state dict:
model.load_state_dict(torch.load('/content/drive/MyDrive/saved_model/model_resnet.pth'))
model.eval()
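The steps above can be sketched end to end as follows. This is only a minimal illustration, not the asker's actual setup: `build_model` is a hypothetical tiny stand-in for the ResNet50 built in the training script, the 6 output classes and the temporary file path are assumptions, and a random tensor takes the place of a transformed image.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model(num_classes=6):
    # Hypothetical stand-in for the ResNet50 created in the training script.
    # The same constructor must be used at training and at inference time.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, num_classes),
    )


# "Training" side: save only the weights, as in the question.
trained = build_model()
path = os.path.join(tempfile.mkdtemp(), "model_demo.pth")  # assumed demo path
torch.save(trained.state_dict(), path)

# torch.load returns the state dict (a mapping of names to tensors).
# It is not a model, so calling it like model(input) raises the TypeError.
loaded = torch.load(path, map_location="cpu")

# Inference side: rebuild the architecture, then load the weights into it.
model = build_model()
model.load_state_dict(loaded)
model.eval()

# A single image tensor has shape (3, 150, 150); the model expects a batch,
# so add a leading batch dimension with unsqueeze(0).
image = torch.rand(3, 150, 150)  # placeholder for transform(image)
with torch.no_grad():
    output = model(image.unsqueeze(0))
prediction = int(output.argmax(dim=1).item())
print(prediction)
```

Note that a batch dimension is added before the forward pass, which the commented-out `input.view(1, 3, 150, 150)` line in the question would also do. For inference it is probably preferable to use the deterministic transform_test pipeline rather than transform_train, since the random flips are training-time augmentation.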
Ref:
https://pytorch.org/tutorials/beginner/saving_loading_models.html