繁体   English   中英

在 PyTorch 中使用 strict=False 加载修改后的预训练 model

[英]Loading a modified pretrained model using strict=False in PyTorch

我想在我的 model 中使用预训练的 model 作为编码器部分。 你可以找到我的 model 的一个版本:

class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        self.encoder=S3D_featureExtractor_multi_output()
        if pretrained:

        weight_dict=torch.load(os.path.join('models','weights.pt'))

        model_dict=self.encoder.state_dict()
        
        list_weight_dict=list(weight_dict.items())
        list_model_dict=list(model_dict.items())
        
        for i in range(len(list_model_dict)):
            assert list_model_dict[i][1].shape==list_weight_dict[i][1].shape
            model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
        
        for i in range(len(list_model_dict)):
            assert torch.all(torch.eq(model_dict[list_model_dict[i][0]],weight_dict[list_weight_dict[i][0]].to('cpu')))
        print('Loading finished!')
                
def forward(self, x):
    
    a, b = self.encoder(x)
    return a, b

因为我修改了这个预训练的 model 的部分代码,基于这篇文章我需要应用strict=False以避免遇到错误,但是基于我加载预训练权重的场景,我无法在代码中找到一个地方应用严格=假。 我该如何应用它或如何更改加载预训练的 model 的场景,这使得应用strict=False成为可能?

strict = False是指定何时使用load_state_dict()方法。 state_dict只是 Python 字典,可帮助您保存和加载 model 权重。 (有关更多详细信息,请参阅https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html

If you use strict=False in load_state_dict , you inform PyTorch that the target model and the original model are not identical, so it just initialises the weights of layers which are present in both and ignores the rest. (参见https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict

因此,您需要在加载预训练的 model 权重时指定 strict 参数。 load_state_dict可以在这一步调用。 如果必须加载权重的 model 是self.encoder并且如果可以从刚刚加载的state_dict中检索到 state_dict,则可以这样做

loaded_weights = torch.load(os.path.join('models','weights.pt'))
self.encoder.load_state_dict(loaded_weights, strict=False)

有关更多详细信息和教程,请参阅https://pytorch.org/tutorials/beginner/saving_loading_models.html

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM