简体   繁体   English

如何在pytorch模型中加载检查点文件?

[英]How to load a checkpoint file in a pytorch model?

In my pytorch model, I'm initializing my model and optimizer like this. 在我的pytorch模型中,我正在像这样初始化我的模型和优化器。

model = MyModelClass(config, shape, x_tr_mean, x_tr,std)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)

And here is the path to my checkpoint file. 这是我的检查点文件的路径。

checkpoint_file = os.path.join(config.save_dir, "checkpoint.pth") checkpoint_file = os.path.join(config.save_dir,“ checkpoint.pth”)

To load this checkpoint file, I check and see if the checkpoint file exists and then I load it as well as the model and optimizer. 要加载此检查点文件,请检查并查看该检查点文件是否存在,然后再加载该文件以及模型和优化器。

if os.path.exists(checkpoint_file):
    if config.resume:
        torch.load(checkpoint_file)
        model.load_state_dict(torch.load(checkpoint_file))
        optimizer.load_state_dict(torch.load(checkpoint_file))

Also, here's how I'm saving my model and optimizer. 另外,这就是我保存模型和优化器的方式。

 torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'iter_idx': iter_idx, 'best_va_acc': best_va_acc}, checkpoint_file)

For some reason I keep getting a strange error whenever I run this code. 由于某种原因,每当我运行此代码时,我总是收到一个奇怪的错误。

model.load_state_dict(torch.load(checkpoint_file))
File "/home/Josh/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MyModelClass:
        Missing key(s) in state_dict: "mean", "std", "attribute.weight", "attribute.bias".
        Unexpected key(s) in state_dict: "model", "optimizer", "iter_idx", "best_va_acc"

Does anyone know why I'm getting this error? 有谁知道为什么我会收到此错误?

You saved the model parameters in a dictionary. 您将模型参数保存在字典中。 You're supposed to use the keys, that you used while saving earlier, to load the model checkpoint and state_dict s like this: 您应该使用之前保存时使用的键来加载模型检查点和state_dict如下所示:

if os.path.exists(checkpoint_file):
    if config.resume:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])

You can check the official tutorial on PyTorch website for more info. 您可以在PyTorch网站上查看官方教程以了解更多信息。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 如何从检查点文件加载经过微调的 pytorch huggingface bert model? - How to load a fine tuned pytorch huggingface bert model from a checkpoint file? 如何在 Pytorch 中保存/加载具有多个损失的模型检查点? - How to save/load a model checkpoint with several losses in Pytorch? 从变压器加载 model 时出现“无法从 pytorch 检查点文件加载权重” - Getting "Unable to load weights from pytorch checkpoint file" when loading model from transformers 无法从 Pytorch-Lightning 中的检查点加载模型 - Unable to load model from checkpoint in Pytorch-Lightning 如何使用pytorch中的checkpoint模型文件来测试CIFAR-10数据集? - How to use checkpoint model file in pytorch to test the CIFAR-10 dataset? 将 pytorch_model.bin 拆分为块后,无法从 pytorch 检查点加载权重 - Unable to load weights from pytorch checkpoint after splitting pytorch_model.bin into chunks 如何在pytorch中加载自定义model - How to load custom model in pytorch 加载 PyTorch NN model 的检查点时的异常 - exceptions when loading the checkpoint of a PyTorch NN model 保存 pytorch model 并加载新文件 - save pytorch model and load in new file 加载张量流检查点作为keras模型 - Load tensorflow checkpoint as keras model
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM