Loading a modified pretrained model using strict=False in PyTorch
I want to use a pretrained model as the encoder part of my model. Here is a version of my model:
```python
import os

import torch
import torch.nn as nn

# S3D_featureExtractor_multi_output is my (modified) pretrained encoder class, defined elsewhere.


class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        self.encoder = S3D_featureExtractor_multi_output()
        if pretrained:
            weight_dict = torch.load(os.path.join('models', 'weights.pt'))
            model_dict = self.encoder.state_dict()
            list_weight_dict = list(weight_dict.items())
            list_model_dict = list(model_dict.items())
            # Copy the pretrained tensors into the encoder, pairing them by position.
            for i in range(len(list_model_dict)):
                assert list_model_dict[i][1].shape == list_weight_dict[i][1].shape
                model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
            # Check that every tensor was copied correctly.
            for i in range(len(list_model_dict)):
                assert torch.all(torch.eq(model_dict[list_model_dict[i][0]],
                                          weight_dict[list_weight_dict[i][0]].to('cpu')))
            print('Loading finished!')

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b
```
Because I modified parts of the code of this pretrained model, based on this post I need to apply `strict=False` to avoid errors. However, given the way I load the pretrained weights, I cannot find a place in the code to apply `strict=False`. How can I apply it, or how can I change the way I load the pretrained model so that applying `strict=False` becomes possible?
`strict=False` is specified when you call the `load_state_dict()` method. A `state_dict` is just a Python dictionary (an `OrderedDict` mapping each parameter and buffer name to its tensor) that helps you save and load model weights (for more details, see https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html ).
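To make this concrete, the sketch below prints the `state_dict` of a small stand-in model (a hypothetical `nn.Sequential`, not the S3D encoder from the question) and round-trips it through `torch.save`/`torch.load`. The keys are the names that `load_state_dict` later uses to match tensors, which is why a modified architecture with renamed or extra layers produces mismatched keys.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in model: one Linear layer followed by one BatchNorm layer.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

state = model.state_dict()
for name, tensor in state.items():
    # Keys are "<submodule index>.<parameter or buffer name>".
    print(name, tuple(tensor.shape))

# Save and reload the dictionary; an in-memory buffer stands in for a .pt file.
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)
restored = torch.load(buffer)

# A fresh model with the same architecture accepts the dictionary as-is.
fresh = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
fresh.load_state_dict(restored)
```

Note that buffers such as `running_mean` and `num_batches_tracked` are part of the `state_dict` alongside the learnable parameters.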
If you pass `strict=False` to `load_state_dict`, you tell PyTorch that the target model and the original model are not identical, so it only loads the weights whose keys (layer names) are present in both and ignores missing or unexpected keys. A key that exists in both but has a different shape still raises a `RuntimeError`, even with `strict=False`. (See https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict )
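The following sketch illustrates this behaviour with two hypothetical stand-in networks: a "pretrained" one and a modified copy with an extra layer. With `strict=False`, `load_state_dict` returns a named tuple listing `missing_keys` and `unexpected_keys` instead of raising, while a layer that keeps its name but changes shape still fails.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a "pretrained" network and a modified copy with one extra layer.
pretrained = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
modified = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4), nn.Linear(4, 2))

# With strict=False, unmatched keys are reported in the return value instead of raising.
result = modified.load_state_dict(pretrained.state_dict(), strict=False)
print('missing:', result.missing_keys)        # ['3.weight', '3.bias']
print('unexpected:', result.unexpected_keys)  # []

# Layers present in both networks now hold the pretrained values.
print(torch.equal(modified[0].weight, pretrained[0].weight))  # True

# A layer that keeps its name but changes shape is still an error.
resized = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))
try:
    resized.load_state_dict(pretrained.state_dict(), strict=False)
    shape_error_raised = False
except RuntimeError as err:
    shape_error_raised = 'size mismatch' in str(err)
print('shape mismatch raised:', shape_error_raised)  # True
```

Printing `result` after loading is a cheap way to confirm that only the layers you intended to change were left uninitialised.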
So you need to pass the `strict` argument at the point where you load the pretrained model weights, which is where `load_state_dict` can be called. If the model whose weights must be loaded is `self.encoder`, and the object you load from the file is a `state_dict`, you can simply do this instead of copying the tensors one by one:
```python
loaded_weights = torch.load(os.path.join('models', 'weights.pt'))
self.encoder.load_state_dict(loaded_weights, strict=False)
```
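To show where that call fits into the question's `MyClass`, here is a sketch that replaces the manual copy loop with `load_state_dict(strict=False)`. Since `S3D_featureExtractor_multi_output` is not available here, `OriginalEncoder` and `ModifiedEncoder` are hypothetical stand-ins, and the `weights_path` parameter is an addition so the sketch can be run against any checkpoint file.

```python
import os

import torch
import torch.nn as nn


class OriginalEncoder(nn.Module):
    """Hypothetical stand-in for the unmodified pretrained encoder (two outputs)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head_a = nn.Linear(16, 4)

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.head_a(h), h


class ModifiedEncoder(OriginalEncoder):
    """Hypothetical modified encoder: the original layers plus one new head."""

    def __init__(self):
        super().__init__()
        self.head_b = nn.Linear(16, 2)  # new layer, absent from the checkpoint

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.head_a(h), self.head_b(h)


class MyClass(nn.Module):
    def __init__(self, pretrained=False, weights_path=os.path.join('models', 'weights.pt')):
        super().__init__()
        self.encoder = ModifiedEncoder()
        if pretrained:
            # map_location='cpu' lets a checkpoint saved on a GPU load on any machine.
            weight_dict = torch.load(weights_path, map_location='cpu')
            # Tensors are matched by key name; unmatched keys are reported, not fatal.
            result = self.encoder.load_state_dict(weight_dict, strict=False)
            print('Missing keys:', result.missing_keys)
            print('Unexpected keys:', result.unexpected_keys)
            print('Loading finished!')

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b
```

In this sketch, saving `OriginalEncoder().state_dict()` to a file and constructing `MyClass(pretrained=True, weights_path=...)` should report `head_b.weight` and `head_b.bias` as missing keys, while `backbone` and `head_a` receive the pretrained values. Unlike the positional loop in the question, matching by name cannot silently pair unrelated tensors.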
For more details and a tutorial, see https://pytorch.org/tutorials/beginner/saving_loading_models.html .
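`strict=False` only tolerates missing and unexpected keys. If a modified layer keeps its name but changes shape (for example, a classifier head with a different number of outputs), `load_state_dict` still raises a size-mismatch `RuntimeError`. One common workaround, sketched below with a hypothetical helper `load_matching_weights`, is to drop those entries from the checkpoint before loading; this is not a built-in PyTorch function.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, checkpoint: dict) -> list:
    """Hypothetical helper: load only checkpoint entries whose name AND shape match.

    Returns the names of the checkpoint entries that were skipped.
    """
    model_state = model.state_dict()
    filtered = {}
    skipped = []
    for name, tensor in checkpoint.items():
        if name in model_state and model_state[name].shape == tensor.shape:
            filtered[name] = tensor
        else:
            skipped.append(name)
    # strict=False tolerates the dropped keys and any new layers in `model`.
    model.load_state_dict(filtered, strict=False)
    return skipped
```

For example, loading a checkpoint with a 4-output final layer into a model whose final layer has 10 outputs would skip only that layer's `weight` and `bias`, leaving them at their fresh initialisation while every other layer receives the pretrained values.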
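Notice: the technical posts on this site are published under the CC BY-SA 4.0 license. If you repost this content, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.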