
How to view weights of a pretrained model after transfer learning?

I'm following a transfer learning tutorial explained here and saved the weights after training using:

torch.save(vgg_based.state_dict(), 'model1.pth')

When I try to load the model like this:

model = torchvision.models.vgg19()
model.load_state_dict(torch.load('model1.pth'))
model.eval()

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-423f6125f9e6> in <module>
      5 # features.extend([torch.nn.Linear(number_features, len(class_names))])
      6 # model.classifier = torch.nn.Sequential(*features)
----> 7 model.load_state_dict(torch.load('model1.pth'))
      8 model.eval()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226 

RuntimeError: Error(s) in loading state_dict for VGG:
    size mismatch for classifier.6.weight: copying a param with shape torch.Size([2, 4096]) from checkpoint, the shape in current model is torch.Size([1000, 4096]).
    size mismatch for classifier.6.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1000])

I understand this happens because I need to modify the network to reflect the number of outputs: before loading the weights, I should apply the same changes I made during training. So I load the model like this:

model = torchvision.models.vgg19()
# replace the last layer so its output size matches the number of classes
number_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1]  # drop the original 1000-class layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
model.classifier = torch.nn.Sequential(*features)
model.load_state_dict(torch.load('model1.pth'))
model.eval()

Is this correct? How can I view the values of the weights and confirm I have loaded the newly trained model correctly?
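A minimal, self-contained sketch of the procedure in the question, plus one way to view the weights and check that they were loaded. To keep it small and fast, it assumes a hypothetical `TinyVGGLike` module instead of `torchvision.models.vgg19()`: only a 7-layer `classifier` laid out like VGG's, so the state_dict keys (`classifier.6.weight`, ...) have the same names. The helper `replace_head` is also hypothetical, and an in-memory buffer stands in for `'model1.pth'`.

```python
import io

import torch
from torch import nn


class TinyVGGLike(nn.Module):
    """Hypothetical stand-in for VGG: only a small 7-layer `classifier`, laid out
    like VGG's, so keys look like 'classifier.6.weight' without building VGG19."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(16, 8), nn.ReLU(True), nn.Dropout(),
            nn.Linear(8, 8), nn.ReLU(True), nn.Dropout(),
            nn.Linear(8, num_classes),
        )


def replace_head(model: nn.Module, num_classes: int) -> nn.Module:
    # Same steps as in the question: keep every layer but the last, append a new Linear.
    number_features = model.classifier[6].in_features
    features = list(model.classifier.children())[:-1]
    features.append(nn.Linear(number_features, num_classes))
    model.classifier = nn.Sequential(*features)
    return model


# Training side: modify the head, then save (a BytesIO buffer stands in for 'model1.pth').
trained = replace_head(TinyVGGLike(), num_classes=2)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Loading side: rebuild the same architecture first, then load the weights.
restored = replace_head(TinyVGGLike(), num_classes=2)
restored.load_state_dict(torch.load(buffer))
restored.eval()

# View the weights: every entry's name and shape, then the raw values of the new head.
for name, tensor in restored.state_dict().items():
    print(name, tuple(tensor.shape))
print(restored.classifier[6].weight.detach())

# Confirm the loaded tensors are identical to the ones that were saved.
saved = trained.state_dict()
weights_match = all(torch.equal(saved[k], v) for k, v in restored.state_dict().items())
print("weights match:", weights_match)
```

In a fresh session the trained model is no longer in memory; there, the same check can be done against a second `torch.load('model1.pth')` instead of `trained.state_dict()`.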

You have changed the last linear layer, so loading the saved state dictionary into an unmodified vgg19() model produces a shape mismatch.
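To see which entries clash before calling load_state_dict, one option is to compare shapes key by key. This is a sketch: `Net` is a hypothetical small stand-in for VGG (its `classifier` has the last Linear at index 6, like VGG's), and `shape_mismatches` is a hypothetical helper, not a PyTorch API.

```python
from torch import nn


class Net(nn.Module):
    """Hypothetical stand-in for VGG: a small `classifier` whose last Linear is at index 6."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(16, 8), nn.ReLU(True), nn.Dropout(),
            nn.Linear(8, 8), nn.ReLU(True), nn.Dropout(),
            nn.Linear(8, num_classes),
        )


def shape_mismatches(model: nn.Module, state: dict) -> list:
    """List (key, checkpoint_shape, model_shape) for keys present in both with different shapes."""
    current = model.state_dict()
    return [
        (key, tuple(tensor.shape), tuple(current[key].shape))
        for key, tensor in state.items()
        if key in current and tensor.shape != current[key].shape
    ]


checkpoint = Net(num_classes=2).state_dict()  # plays the role of torch.load('model1.pth')
model = Net(num_classes=1000)                 # unmodified model with a 1000-class head

for key, saved_shape, model_shape in shape_mismatches(model, checkpoint):
    print(f"{key}: checkpoint {saved_shape} vs model {model_shape}")
# classifier.6.weight: checkpoint (2, 8) vs model (1000, 8)
# classifier.6.bias: checkpoint (2,) vs model (1000,)
```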

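  • You can either relax the check by setting the strict argument of load_state_dict to False, so that missing and unexpected keys are reported instead of raising an error:

     state = torch.load('model1.pth')
     diff = model.load_state_dict(state, strict=False)
     print(diff)

    diff is a named tuple listing the missing_keys and unexpected_keys. Note that strict=False does not cover size mismatches: a key that exists in both the checkpoint and the model with different shapes, such as classifier.6.weight here, still raises the RuntimeError above, so the mismatched entries have to be removed as well (next option).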

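  • Or pop the concerning layers out of the dictionary before loading, and pass strict=False so the now-missing keys are tolerated:

     state = torch.load('model1.pth')
     state.pop('classifier.6.weight')
     state.pop('classifier.6.bias')
     model.load_state_dict(state, strict=False)

    The last layer then keeps its fresh initialization, which suits reusing the backbone for a new task but does not restore the trained classifier.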
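A small sketch of how load_state_dict behaves in these two cases. `Net` is a hypothetical stand-in for VGG with a small `classifier` (last Linear at index 6), and a freshly built 2-class model's state_dict plays the role of `torch.load('model1.pth')`.

```python
from torch import nn


class Net(nn.Module):
    """Hypothetical stand-in for VGG: a small `classifier` whose last Linear is at index 6."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(16, 8), nn.ReLU(True), nn.Dropout(),
            nn.Linear(8, 8), nn.ReLU(True), nn.Dropout(),
            nn.Linear(8, num_classes),
        )


state = Net(num_classes=2).state_dict()  # plays the role of torch.load('model1.pth')
model = Net(num_classes=1000)            # unmodified model with a 1000-class head

# strict=False alone: the size mismatch on classifier.6.* is still an error.
try:
    model.load_state_dict(state, strict=False)
    error_message = ""
except RuntimeError as err:
    error_message = str(err)
print("size mismatch raised:", "size mismatch" in error_message)

# Pop the mismatched entries, then load with strict=False so their absence is tolerated.
state.pop("classifier.6.weight")
state.pop("classifier.6.bias")
result = model.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)        # the two popped keys
print("unexpected:", result.unexpected_keys)  # nothing extra in the checkpoint
```

Popping is only needed when the head should be discarded; to restore the trained 2-class head, rebuilding the classifier before loading (as in the question) keeps every key and shape aligned.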
