简体   繁体   English

在 PyTorch 中使用 strict=False 加载修改后的预训练 model

[英]Loading a modified pretrained model using strict=False in PyTorch

I want to use a pretrained model as the encoder part in my model.我想在我的 model 中使用预训练的 model 作为编码器部分。 You can find a version of my model:你可以找到我的 model 的一个版本:

class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        self.encoder=S3D_featureExtractor_multi_output()
        if pretrained:

        weight_dict=torch.load(os.path.join('models','weights.pt'))

        model_dict=self.encoder.state_dict()
        
        list_weight_dict=list(weight_dict.items())
        list_model_dict=list(model_dict.items())
        
        for i in range(len(list_model_dict)):
            assert list_model_dict[i][1].shape==list_weight_dict[i][1].shape
            model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
        
        for i in range(len(list_model_dict)):
            assert torch.all(torch.eq(model_dict[list_model_dict[i][0]],weight_dict[list_weight_dict[i][0]].to('cpu')))
        print('Loading finished!')
                
def forward(self, x):
    
    a, b = self.encoder(x)
    return a, b

Because I modified some parts of the code of this pretrained model, based on this post I need to apply strict=False to avoid facing error, but based on the scenario that I load the pretrained weights, I cannot find a place in the code to apply strict=False.因为我修改了这个预训练的 model 的部分代码,基于这篇文章我需要应用strict=False以避免遇到错误,但是基于我加载预训练权重的场景,我无法在代码中找到一个地方应用严格=假。 How can I apply that or how can I change the scenario of loading the pretrained model taht makes it possible to apply strict=False ?我该如何应用它或如何更改加载预训练的 model 的场景,这使得应用strict=False成为可能?

strict = False is to specify when you use load_state_dict() method. strict = False是指定何时使用load_state_dict()方法。 state_dict are just Python dictionaries that helps you save and load model weights. state_dict只是 Python 字典,可帮助您保存和加载 model 权重。 (for more details, see https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html ) (有关更多详细信息,请参阅https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html

If you use strict=False in load_state_dict , you inform PyTorch that the target model and the original model are not identical, so it just initialises the weights of layers which are present in both and ignores the rest. If you use strict=False in load_state_dict , you inform PyTorch that the target model and the original model are not identical, so it just initialises the weights of layers which are present in both and ignores the rest. (see https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict ) (参见https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict

So, you will need to specify the strict argument when you load the pretrained model weights.因此,您需要在加载预训练的 model 权重时指定 strict 参数。 load_state_dict can be called at this step. load_state_dict可以在这一步调用。 If the model for which weights must be loaded is self.encoder and if state_dict can be retrieved from the model you just loaded, you can just do this如果必须加载权重的 model 是self.encoder并且如果可以从刚刚加载的state_dict中检索到 state_dict,则可以这样做

loaded_weights = torch.load(os.path.join('models','weights.pt'))
self.encoder.load_state_dict(loaded_weights, strict=False)

for more details and a tutorial, see https://pytorch.org/tutorials/beginner/saving_loading_models.html .有关更多详细信息和教程,请参阅https://pytorch.org/tutorials/beginner/saving_loading_models.html

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM