繁体   English   中英

如何加载部分预训练的 pytorch model?

[英]How can I load a partial pretrained pytorch model?

我正在尝试让 pytorch model 在句子分类任务上运行。 在处理医学笔记时,我正在使用 ClinicalBert ( https://github.com/kexinhuang12345/clinicalBERT ) 并希望使用其预先训练的权重。 不幸的是,ClinicalBert model 仅将文本分类为 1 个二进制 label 而我有 281 个二进制标签。 因此,我正在尝试实现此代码https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb ,其中 bert 之后的最终分类器长度为 281。

如何在不加载分类权重的情况下从 ClinicalBert model 加载预训练的 Bert 权重?

天真地尝试从预训练的 ClinicalBert 权重中加载权重,我收到以下错误:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

我目前尝试从 pytorch_pretrained_bert package 替换 from_pretrained function 并像这样弹出分类器权重和偏差:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

我收到以下错误消息: INFO - modeling_diagnosis - BertForMultiLabelSequenceClassification 的权重未从预训练的 model 初始化:['classifier.weight', 'classifier.bias']

最后,我想从clinicalBert预训练权重中加载bert嵌入,并随机初始化顶级分类器权重。

在加载之前删除 state 字典中的键是一个好的开始。 假设您使用nn.Module.load_state_dict加载预训练的权重,那么您还需要设置strict=False参数以避免意外或丢失键导致的错误。 这将忽略 model 中不存在的 state_dict 条目(意外键),并且对您来说更重要的是,将使用默认初始化(丢失键)保留丢失的条目。 为了安全起见,您可以检查方法的返回值,以验证有问题的权重是丢失键的一部分,并且没有任何意外键。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM