簡體   English   中英

遷移學習后如何查看預訓練模型的權重?

[英]How to view weights of a pretrained model after transfer learning?

我正在關注此處解釋的遷移學習教程,並在訓練后使用以下方法保存了權重:

torch.save(vgg_based.state_dict(), 'model1.pth')

當我嘗試像這樣加載模型時:

model = torchvision.models.vgg19()
model.load_state_dict(torch.load('model1.pth'))
model.eval()

我收到以下錯誤:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-423f6125f9e6> in <module>
      5 # features.extend([torch.nn.Linear(number_features, len(class_names))])
      6 # model.classifier = torch.nn.Sequential(*features)
----> 7 model.load_state_dict(torch.load('model1.pth'))
      8 model.eval()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226 

RuntimeError: Error(s) in loading state_dict for VGG:
    size mismatch for classifier.6.weight: copying a param with shape torch.Size([2, 4096]) from checkpoint, the shape in current model is torch.Size([1000, 4096]).
    size mismatch for classifier.6.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1000])

我理解這是因為我需要編輯網絡以反映輸出數量,因此在加載模型之前,我應該在訓練期間遵循相同的程序,因此加載模型如下:

model = torchvision.models.vgg19()
#modify the last layers
number_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1]  #remove the last layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
model.classifier = torch.nn.Sequential(*features)
model.load_state_dict(torch.load('model1.pth'))
model.eval()

這個對嗎? 如何查看權重值並確認我已正確加載新訓練的模型?

您已更改最后一個線性層,因此在當前模型上加載狀態字典時會出現形狀不匹配。

  • 您可以通過在load_state_dict strict 參數設置為 False 來將其設為警告而不是引發錯誤來丟棄它:

     state = torch.load('model1.pth') diff = model.load_state_dict(state, strict=False) print(diff)

    diff將為您提供所有不匹配條目的列表。

  • 或者在加載之前從字典中完全彈出相關層:

     state = torch.load('model1.pth') state.pop('classifier.6.weight') state.pop('classifier.6.bias') model.load_state_dict(state)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM