![](/img/trans.png)
[英]Loading a pretrained model in PyTorch, error:object not callable
[英]Loading a modified pretrained model using strict=False in PyTorch
我想在我的 model 中使用预训练的 model 作为编码器部分。 你可以找到我的 model 的一个版本:
class MyClass(nn.Module):
def __init__(self, pretrained=False):
super(MyClass, self).__init__()
self.encoder=S3D_featureExtractor_multi_output()
if pretrained:
weight_dict=torch.load(os.path.join('models','weights.pt'))
model_dict=self.encoder.state_dict()
list_weight_dict=list(weight_dict.items())
list_model_dict=list(model_dict.items())
for i in range(len(list_model_dict)):
assert list_model_dict[i][1].shape==list_weight_dict[i][1].shape
model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
for i in range(len(list_model_dict)):
assert torch.all(torch.eq(model_dict[list_model_dict[i][0]],weight_dict[list_weight_dict[i][0]].to('cpu')))
print('Loading finished!')
def forward(self, x):
a, b = self.encoder(x)
return a, b
因为我修改了这个预训练的 model 的部分代码,基于这篇文章我需要应用strict=False
以避免遇到错误,但是基于我加载预训练权重的场景,我无法在代码中找到一个地方应用严格=假。 我该如何应用它或如何更改加载预训练的 model 的场景,这使得应用strict=False
成为可能?
strict = False
是指定何时使用load_state_dict()
方法。 state_dict
只是 Python 字典,可帮助您保存和加载 model 权重。 (有关更多详细信息,请参阅https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html )
If you use strict=False
in load_state_dict
, you inform PyTorch that the target model and the original model are not identical, so it just initialises the weights of layers which are present in both and ignores the rest. (参见https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict )
因此,您需要在加载预训练的 model 权重时指定 strict 参数。 load_state_dict
可以在这一步调用。 如果必须加载权重的 model 是self.encoder
并且如果可以从刚刚加载的state_dict
中检索到 state_dict,则可以这样做
loaded_weights = torch.load(os.path.join('models','weights.pt'))
self.encoder.load_state_dict(loaded_weights, strict=False)
有关更多详细信息和教程,请参阅https://pytorch.org/tutorials/beginner/saving_loading_models.html 。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.