
Loading a modified pretrained model using strict=False in PyTorch

I want to use a pretrained model as the encoder part of my model. Here is a simplified version of my model:

import os

import torch
import torch.nn as nn

# S3D_featureExtractor_multi_output is my own (modified) encoder class, defined elsewhere


class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        self.encoder = S3D_featureExtractor_multi_output()
        if pretrained:
            weight_dict = torch.load(os.path.join('models', 'weights.pt'))
            model_dict = self.encoder.state_dict()

            list_weight_dict = list(weight_dict.items())
            list_model_dict = list(model_dict.items())

            # Copy the pretrained tensors into the encoder, matching them by position
            for i in range(len(list_model_dict)):
                assert list_model_dict[i][1].shape == list_weight_dict[i][1].shape
                model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])

            # Check that every tensor was copied correctly
            for i in range(len(list_model_dict)):
                assert torch.all(torch.eq(model_dict[list_model_dict[i][0]], weight_dict[list_weight_dict[i][0]].to('cpu')))
            print('Loading finished!')

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b

Because I modified some parts of this pretrained model's code, I need to pass strict=False (based on this post) to avoid errors. However, with the way I currently load the pretrained weights, I cannot find anywhere in the code to apply strict=False. How can I apply it, or how can I change the way I load the pretrained model so that strict=False can be used?

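strict=False is an argument of the load_state_dict() method, so it only applies when you load weights through that method. A state_dict is just a Python dictionary (an OrderedDict) that maps each parameter and buffer name to its tensor; it is what you use to save and load model weights. (For more details, see https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html )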
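To make this concrete, here is a minimal sketch that prints what a state_dict contains. It uses a small hypothetical nn.Sequential model (not the S3D encoder from the question): each key is a parameter name and each value is a tensor.

```python
import torch.nn as nn

# Hypothetical toy model, for illustration only
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

state = model.state_dict()  # OrderedDict: parameter/buffer name -> tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

The ReLU at index 1 has no parameters, so it contributes no keys. These names are exactly what load_state_dict matches on, which is why renaming or adding layers in a modified model changes which entries can be loaded.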

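If you pass strict=False to load_state_dict, you tell PyTorch that the keys in the checkpoint do not have to match the keys of the target model exactly: it copies the weights of every parameter whose name appears in both, and ignores the rest. Parameters that exist only in your model keep their current (e.g. random) initialization, and checkpoint entries with no matching name are skipped. Note that a parameter whose name matches but whose shape differs still raises a size-mismatch error, even with strict=False. The method returns the missing_keys and unexpected_keys it found, so you can check what was skipped. (See https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict )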
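The sketch below illustrates this behaviour with two hypothetical toy modules, Original and Modified. The shared conv layer is loaded, the renamed head shows up in missing_keys and unexpected_keys, and a same-name tensor with the wrong shape still raises a RuntimeError.

```python
import torch
import torch.nn as nn


class Original(nn.Module):  # hypothetical "pretrained" architecture
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.head = nn.Linear(8, 10)


class Modified(nn.Module):  # hypothetical modified architecture
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)  # same name and shape: gets loaded
        self.new_head = nn.Linear(8, 5)             # new name: keeps its initial values


original = Original()
model = Modified()

result = model.load_state_dict(original.state_dict(), strict=False)
print("missing:", result.missing_keys)        # in the model, not in the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, not in the model

# Same name but a different shape: strict=False does not skip this, it raises.
bad = {"conv.weight": torch.zeros(16, 3, 3, 3)}
raised = False
try:
    model.load_state_dict(bad, strict=False)
except RuntimeError as err:
    raised = "size mismatch" in str(err)
print("shape mismatch raised:", raised)
```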

So you need to pass the strict argument at the point where you load the pretrained weights, which means calling load_state_dict instead of copying the tensors one by one. If the module that should receive the weights is self.encoder, and weights.pt holds a state_dict (for example, one saved with torch.save(model.state_dict(), path)), you can replace the whole manual copying block with:

loaded_weights = torch.load(os.path.join('models', 'weights.pt'))
# strict=False: keys that appear in only one of the two state dicts are ignored
self.encoder.load_state_dict(loaded_weights, strict=False)
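If some layers in your modified encoder also changed shape, strict=False on its own will not help, because shape mismatches still raise. A common workaround (not part of the original answer) is to filter the checkpoint down to entries whose name and shape both match before calling load_state_dict. This is a sketch under assumptions: TinyEncoder is a hypothetical stand-in for S3D_featureExtractor_multi_output, and the weights_path parameter is added only so the example can write and read its own temporary checkpoint.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for S3D_featureExtractor_multi_output (two outputs)."""

    def __init__(self, out_features=16):
        super().__init__()
        self.stem = nn.Linear(8, 32)
        self.head = nn.Linear(32, out_features)

    def forward(self, x):
        a = torch.relu(self.stem(x))
        return a, self.head(a)


class MyClass(nn.Module):
    def __init__(self, pretrained=False, weights_path=os.path.join('models', 'weights.pt')):
        super().__init__()
        self.encoder = TinyEncoder(out_features=4)  # "modified": head size changed
        if pretrained:
            weight_dict = torch.load(weights_path, map_location='cpu')
            model_dict = self.encoder.state_dict()
            # Keep only entries whose name AND shape match the modified encoder
            compatible = {k: v for k, v in weight_dict.items()
                          if k in model_dict and v.shape == model_dict[k].shape}
            result = self.encoder.load_state_dict(compatible, strict=False)
            print('Loaded:', sorted(compatible))
            print('Not loaded (kept initial values):', result.missing_keys)

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b


# Usage: save a "pretrained" encoder with a different head size, then load it
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'weights.pt')
    pretrained_encoder = TinyEncoder(out_features=16)
    torch.save(pretrained_encoder.state_dict(), path)
    model = MyClass(pretrained=True, weights_path=path)
```

Printing the loaded and missing keys is optional, but it makes it easy to confirm that only the layers you changed were left at their initial values.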

For more details and a tutorial, see https://pytorch.org/tutorials/beginner/saving_loading_models.html .
