
How can I load a partial pretrained pytorch model?

I'm trying to get a pytorch model running on a sentence classification task. As I am working with medical notes, I am using ClinicalBert ( https://github.com/kexinhuang12345/clinicalBERT ) and would like to use its pre-trained weights. Unfortunately, the ClinicalBert model only classifies text into 1 binary label, while I have 281 binary labels. I am therefore trying to implement this code https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb , where the final classifier after BERT has 281 outputs.

How can I load the pre-trained Bert weights from the ClinicalBert model without loading the classification weights?

Naively trying to load the pretrained ClinicalBert weights, I get the following error:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

I have currently tried replacing the from_pretrained function from the pytorch_pretrained_bert package and popping the classifier weights and biases, like this:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        # Fall back to the checkpoint file shipped with the pretrained model
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    # Drop the 2-way classifier head so its shape cannot clash with the new 281-way head
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

And I then get the following message: INFO - modeling_diagnosis - Weights of BertForMultiLabelSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']

In the end, I would like to load the BERT embeddings from the ClinicalBert pretrained weights and have the top classifier weights initialized randomly.

Removing the keys in the state dict before loading is a good start. Assuming you're using nn.Module.load_state_dict to load the pretrained weights, you'll also need to pass the strict=False argument to avoid errors from unexpected or missing keys. This ignores entries in the state_dict that aren't present in the model (unexpected keys) and, more importantly for you, leaves the missing entries at their default initialization (missing keys). For safety, you can check the method's return value to verify that the weights in question are among the missing keys and that there aren't any unexpected keys.
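For concreteness, here is a minimal sketch of that approach. It assumes model is your 281-label BertForMultiLabelSequenceClassification instance and weights_path points to the ClinicalBert checkpoint file; both names are placeholders, not part of the original code.

import torch

# Load the ClinicalBert checkpoint and drop its 2-way classifier head
# so the shapes can never clash with the new 281-way head.
state_dict = torch.load(weights_path, map_location='cpu')
state_dict.pop('classifier.weight', None)
state_dict.pop('classifier.bias', None)

# strict=False ignores unexpected keys and leaves missing keys
# (here: the classifier head) at their default random initialization.
result = model.load_state_dict(state_dict, strict=False)

# Sanity check: only the classifier head should be missing from the
# checkpoint, and the checkpoint should contain nothing unexpected.
assert set(result.missing_keys) == {'classifier.weight', 'classifier.bias'}
assert not result.unexpected_keys

With this in place, the INFO message about classifier.weight and classifier.bias not being initialized from the pretrained model is expected: it confirms that the classifier head keeps its random initialization while the BERT weights are loaded from the checkpoint.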
