[英]Loading a model with pytorch
我在图像分类器项目中加载我的模型时遇到问题。 首先,我保存了它:
model.class_to_idx = train_data.class_to_idx
checkpoint = {'arch': 'vgg19',
'learn_rate': learn_rate,
'epochs': epochs,
'state_dict': model.state_dict(),
'class_to_idx': model.class_to_idx,
'optimizer': optimizer.state_dict(),
'input_size': 25088,
'output_size': 102,
'momentum': momentum,
'batch_size':64,
'classifier' : classifier
}
torch.save(checkpoint, 'checkpoint.pth')
然后我尝试加载我保存的项目:
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
learn_rate = checkpoint['learn_rate']
optimizer.load_state_dict(checkpoint['optimizer'])
model = models.vgg16(pretrained=True)
model.epochs = checkpoint['epochs']
model.load_state_dict(checkpoint['state_dict'])
model.class_to_idx = checkpoint['class_to_idx']
model.classifier = checkpoint['classifier']
return learn_rate, optimizer, model
learn_rate, optimizer, model = load_checkpoint('checkpoint.pth')
当我尝试加载时出现错误:
<ipython-input-75-5bd1aa042c7f> in load_checkpoint(filepath)
9 model = models.vgg16(pretrained=True)
10 model.epochs = checkpoint['epochs']
---> 11 model.load_state_dict(checkpoint['state_dict'])
12 model.class_to_idx = checkpoint['class_to_idx']
13 model.classifier = checkpoint['classifier']
RuntimeError: Error(s) in loading state_dict for VGG:
Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias".
Unexpected key(s) in state_dict: "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias".
这似乎是分类器问题。 有谁知道发生了什么?
jodag 的评论指出了问题的核心。 如果fc1 fc2对应classifier.0classifier.3,classifier.6,你可以调整字典来链接它们。 将权重加载到模型时,请确保添加选项 strict=False。
您将需要为分类器重新训练您的模型 - 因为您的状态字典错过了 3 层的权重,但有 2 个未使用的层权重 - 但它应该会很快收敛(根据个人经验)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.