How to view weights of a pretrained model after transfer learning?
I am following the transfer learning tutorial explained here, and after training I saved the weights as follows:
torch.save(vgg_based.state_dict(), 'model1.pth')
When I try to load the model like this:
model = torchvision.models.vgg19()
model.load_state_dict(torch.load('model1.pth'))
model.eval()
I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-423f6125f9e6> in <module>
5 # features.extend([torch.nn.Linear(number_features, len(class_names))])
6 # model.classifier = torch.nn.Sequential(*features)
----> 7 model.load_state_dict(torch.load('model1.pth'))
8 model.eval()
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1222 if len(error_msgs) > 0:
1223 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224 self.__class__.__name__, "\n\t".join(error_msgs)))
1225 return _IncompatibleKeys(missing_keys, unexpected_keys)
1226
RuntimeError: Error(s) in loading state_dict for VGG:
size mismatch for classifier.6.weight: copying a param with shape torch.Size([2, 4096]) from checkpoint, the shape in current model is torch.Size([1000, 4096]).
size mismatch for classifier.6.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1000])
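I understand this happens because the network has to be edited to reflect the number of outputs, so before loading the weights I should apply the same modification I used during training. I therefore load the model as follows: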
model = torchvision.models.vgg19()
#modify the last layers
number_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1] #remove the last layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
model.classifier = torch.nn.Sequential(*features)
model.load_state_dict(torch.load('model1.pth'))
model.eval()
Is this correct? How can I view the weight values and confirm that I have correctly loaded the newly trained model?
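You changed the last linear layer, so the shapes in the checkpoint no longer match the current model when you load the state dictionary.
One option is to set the strict argument of load_state_dict to False, which reports missing or unexpected keys instead of raising an error. As far as I know, this does not suppress a size mismatch on a key that exists on both sides, so on its own it may still raise for the error shown above: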
state = torch.load('model1.pth')
diff = model.load_state_dict(state, strict=False)
print(diff)
diff lists the keys that are missing from the checkpoint and the keys the model did not expect.
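To make the behaviour of strict=False concrete, here is a minimal sketch. It uses a tiny hypothetical make_model helper as a stand-in for VGG19 (an assumption for illustration, not part of the original code), so the keys are "0.weight", "2.weight" and so on rather than "classifier.6.weight". It shows that a missing key is only reported, while a shape mismatch still raises in the PyTorch versions I am aware of.

```python
import torch
from torch import nn


def make_model(num_classes):
    # Hypothetical small stand-in for VGG19; only the last layer size varies.
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, num_classes))


# Pretend this is the fine-tuned 2-class checkpoint.
state = make_model(num_classes=2).state_dict()

# Case 1: the checkpoint lacks the last layer's keys -> reported, not raised.
partial = {k: v for k, v in state.items() if not k.startswith("2.")}
diff = make_model(num_classes=2).load_state_dict(partial, strict=False)
print(diff.missing_keys)     # keys the model expected but did not find
print(diff.unexpected_keys)  # keys in the checkpoint that the model lacks

# Case 2: the key exists on both sides but the shapes differ.
try:
    make_model(num_classes=1000).load_state_dict(state, strict=False)
    size_mismatch_raised = False
except RuntimeError as err:
    size_mismatch_raised = True
    print(err)
```

If the second case raises on your installation as well, strict=False alone will not get past the classifier.6 error, and the mismatched entries have to be removed from the dictionary first.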
Alternatively, pop the relevant layer's entries from the dictionary before loading:
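state = torch.load('model1.pth')
# Drop the final layer's parameters so they are not copied into the model
state.pop('classifier.6.weight')
state.pop('classifier.6.bias')
# strict=False is needed because the popped keys are now missing from state
model.load_state_dict(state, strict=False)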
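Note that popping the last layer leaves it at its initial values, so the fine-tuned head is lost. If the goal is to reuse the trained model, rebuilding the same architecture before loading (as in the question) keeps every weight. The sketch below shows one way to view the loaded values and confirm they match the checkpoint. The make_model helper and the temporary file are assumptions for illustration; on the real VGG19 the last layer's key would be 'classifier.6.weight' instead of '2.weight'.

```python
import os
import tempfile

import torch
from torch import nn


def make_model(num_classes):
    # Hypothetical small stand-in for the modified VGG19.
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, num_classes))


trained = make_model(num_classes=2)  # pretend this model was fine-tuned

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model1.pth")
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture first, then load with the default strict=True.
    model = make_model(num_classes=2)
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)

model.eval()

# View the values of a single tensor by its state_dict key.
print(model.state_dict()["2.weight"])

# Confirm that every tensor in the model equals the one in the checkpoint.
all_match = all(
    torch.equal(model.state_dict()[key], value) for key, value in state.items()
)
print(all_match)
```

Because a freshly built model starts from random values, an exact match on every key is reasonable evidence that the checkpoint was actually copied in.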