繁体   English   中英

迁移学习后如何查看预训练模型的权重?

[英]How to view weights of a pretrained model after transfer learning?

我正在关注此处解释的迁移学习教程,并在训练后使用以下方法保存了权重:

torch.save(vgg_based.state_dict(), 'model1.pth')

当我尝试像这样加载模型时:

model = torchvision.models.vgg19()
model.load_state_dict(torch.load('model1.pth'))
model.eval()

我收到以下错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-423f6125f9e6> in <module>
      5 # features.extend([torch.nn.Linear(number_features, len(class_names))])
      6 # model.classifier = torch.nn.Sequential(*features)
----> 7 model.load_state_dict(torch.load('model1.pth'))
      8 model.eval()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226 

RuntimeError: Error(s) in loading state_dict for VGG:
    size mismatch for classifier.6.weight: copying a param with shape torch.Size([2, 4096]) from checkpoint, the shape in current model is torch.Size([1000, 4096]).
    size mismatch for classifier.6.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1000])

我理解这是因为我需要编辑网络以反映输出数量,因此在加载模型之前,我应该在训练期间遵循相同的程序,因此加载模型如下:

model = torchvision.models.vgg19()
#modify the last layers
number_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1]  #remove the last layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
model.classifier = torch.nn.Sequential(*features)
model.load_state_dict(torch.load('model1.pth'))
model.eval()

这个对吗? 如何查看权重值并确认我已正确加载新训练的模型?

您已更改最后一个线性层,因此在当前模型上加载状态字典时会出现形状不匹配。

  • 您可以通过在load_state_dict strict 参数设置为 False 来将其设为警告而不是引发错误来丢弃它:

     state = torch.load('model1.pth') diff = model.load_state_dict(state, strict=False) print(diff)

    diff将为您提供所有不匹配条目的列表。

  • 或者在加载之前从字典中完全弹出相关层:

     state = torch.load('model1.pth') state.pop('classifier.6.weight') state.pop('classifier.6.bias') model.load_state_dict(state)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM