
PyTorch loading pretrained weights

I am trying to load a pretrained model file, resnet_18.pth, into PyTorch. Online documentation suggested loading it like so:

import torch

weights = torch.load("resnet_18.pth")

When I print weights, the output looks something like this:

 ('module.layer4.1.bn2.running_mean', tensor([ 9.1797e+01, -2.4204e+02,  5.6480e+01, -2.0762e+02,  4.5270e+01,
        -3.2356e+02,  1.8662e+02, -1.4498e+02, -2.3701e+02,  3.2354e+01,
...
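What torch.load returned here looks like a state dict: an ordered mapping from parameter and buffer names to tensors, not a model object. A minimal sketch for summarizing such a file without printing every value, assuming it holds a plain state dict as the output above suggests; the helper name summarize_checkpoint is hypothetical:

```python
import torch


def summarize_checkpoint(path):
    """Return (name, shape) pairs for every tensor in a saved state dict."""
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(path, map_location="cpu")
    # Report each key with its tensor shape instead of the full tensor contents.
    return [(name, tuple(t.shape)) for name, t in state_dict.items()]
```

Calling summarize_checkpoint("resnet_18.pth") would then list names such as module.layer4.1.bn2.running_mean alongside their shapes.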

All of the tutorials I found load the weights into an instance of a base model:

import torch

# TheModelClass, args, kwargs and PATH are placeholders from the tutorial template
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

I want to use a default ResNet-18 model to apply the weights to, but the resnet18 from torchvision does not have a load_state_dict function. Any help is appreciated.

import torch
from torchvision.models import resnet18

resnet18.load_state_dict(torch.load("resnet_18.pth"))
# AttributeError: 'function' object has no attribute 'load_state_dict'

resnet18 is itself a function that returns a ResNet-18 model, not the model itself. To load your own pretrained weights, first call it to build the model, then load the state dict into the returned instance:

import torch
from torchvision.models import resnet18

model = resnet18()
model.load_state_dict(torch.load("resnet_18.pth"))
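The keys printed in the question begin with module. (for example module.layer4.1.bn2.running_mean). That prefix usually appears when a model was saved while wrapped in torch.nn.DataParallel, whereas a plain resnet18() expects the same keys without it, so load_state_dict may report missing and unexpected keys. A minimal sketch of removing the prefix before loading, assuming the extra module. prefix is the only mismatch; the helper name strip_module_prefix is hypothetical:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Keys that do not start with the prefix are kept unchanged, and the input
    mapping is not modified.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
```

With that helper, the load step would become model.load_state_dict(strip_module_prefix(torch.load("resnet_18.pth"))). Returning a copy rather than editing keys in place keeps the originally loaded dict intact for comparison.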

Note that load_state_dict(...) loads the weights in place and does not return the model itself.
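Because of this, chaining the calls as model = resnet18().load_state_dict(...) would leave model holding the return value rather than the network. In recent PyTorch versions that return value is a named tuple with missing_keys and unexpected_keys fields, which is handy for checking that every weight matched. A minimal sketch, using a small torch.nn.Linear in place of ResNet-18 so it runs without a checkpoint file; the helper name load_and_check is hypothetical:

```python
import torch
from torch import nn


def load_and_check(model, state_dict):
    """Load state_dict into model in place and return the key mismatches."""
    # strict=False reports mismatches instead of raising a RuntimeError.
    result = model.load_state_dict(state_dict, strict=False)
    # result is a named tuple, not the model; the model was updated in place.
    return list(result.missing_keys), list(result.unexpected_keys)


source = nn.Linear(3, 2)
target = nn.Linear(3, 2)
missing, unexpected = load_and_check(target, source.state_dict())
print(missing, unexpected)  # [] []
print(torch.equal(source.weight, target.weight))  # True
```

Keeping the default strict=True is usually safer once the keys are known to match; strict=False is used here only to inspect mismatches.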
