簡體   English   中英

如何加載部分預訓練的 pytorch model?

[英]How can I load a partial pretrained pytorch model?

我正在嘗試讓 pytorch model 在句子分類任務上運行。 在處理醫學筆記時,我正在使用 ClinicalBert ( https://github.com/kexinhuang12345/clinicalBERT ) 並希望使用其預先訓練的權重。 不幸的是,ClinicalBert model 僅將文本分類為 1 個二進制 label 而我有 281 個二進制標簽。 因此,我正在嘗試實現此代碼https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb ,其中 bert 之后的最終分類器長度為 281。

如何在不加載分類權重的情況下從 ClinicalBert model 加載預訓練的 Bert 權重?

天真地嘗試從預訓練的 ClinicalBert 權重中加載權重,我收到以下錯誤:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

我目前嘗試從 pytorch_pretrained_bert package 替換 from_pretrained function 並像這樣彈出分類器權重和偏差:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

我收到以下錯誤消息: INFO - modeling_diagnosis - BertForMultiLabelSequenceClassification 的權重未從預訓練的 model 初始化:['classifier.weight', 'classifier.bias']

最后,我想從clinicalBert預訓練權重中加載bert嵌入,並隨機初始化頂級分類器權重。

在加載之前刪除 state 字典中的鍵是一個好的開始。 假設您使用nn.Module.load_state_dict加載預訓練的權重,那么您還需要設置strict=False參數以避免意外或丟失鍵導致的錯誤。 這將忽略 model 中不存在的 state_dict 條目(意外鍵),並且對您來說更重要的是,將使用默認初始化(丟失鍵)保留丟失的條目。 為了安全起見,您可以檢查方法的返回值,以驗證有問題的權重是丟失鍵的一部分,並且沒有任何意外鍵。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM