
How do I save a trained model in PyTorch?

How do I save a trained model in PyTorch? I have read that:

  1. torch.save() / torch.load() is for saving/loading a serializable object.
  2. model.state_dict() / model.load_state_dict() is for saving/loading model state.

I've found this page on their GitHub repo; I'll just paste the content here.


Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
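
One caveat worth noting: if the checkpoint was saved on a GPU and you later load it on a CPU-only machine, torch.load accepts a map_location argument to remap the tensors at load time:

the_model = TheModelClass(*args, **kwargs)
# Remap GPU-saved tensors to the CPU at load time
the_model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))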

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.

It depends on what you want to do.

Case #1: Save the model to use it yourself for inference: You save the model, you restore it, and then you switch the model to evaluation mode. This is done because you usually have BatchNorm and Dropout layers that by default are in train mode on construction:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
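
Once restored, inference is typically wrapped in torch.no_grad() so that no gradients are tracked; a minimal sketch (the input here is only an example; its shape depends on your model):

with torch.no_grad():
    output = model(torch.randn(1, 10))  # example input; shape depends on your model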

Case #2: Save the model to resume training later: If you need to keep training the model that you are about to save, you need to save more than just the model. You also need to save the state of the optimizer, the epoch, the score, etc. You would do it like this:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

To resume training you would do things like: state = torch.load(filepath), and then, to restore the state of each individual object, something like this:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Since you are resuming training, DO NOT call model.eval() once you restore the states when loading.
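
Instead, put the model back into train mode and continue the loop from the saved epoch. A minimal sketch, assuming the usual training-loop objects (num_epochs, train_loader, criterion) are defined elsewhere:

model.train()  # ensure BatchNorm/Dropout layers are back in train mode
start_epoch = state['epoch'] + 1
for epoch in range(start_epoch, num_epochs):      # num_epochs assumed defined
    for inputs, targets in train_loader:          # train_loader assumed defined
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)  # criterion assumed defined
        loss.backward()
        optimizer.step()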

Case #3: Model to be used by someone else with no access to your code: In TensorFlow you can create a .pb file that defines both the architecture and the weights of the model. This is very handy, especially when using TensorFlow Serving. The equivalent way to do this in PyTorch would be:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

This way is still not bulletproof, and since PyTorch is still undergoing a lot of changes, I wouldn't recommend it.

The pickle Python library implements binary protocols for serializing and de-serializing a Python object.

When you import torch (or when you use PyTorch), it will import pickle for you, and you don't need to call pickle.dump() and pickle.load() directly, which are the methods to save and to load the object.

In fact, torch.save() and torch.load() will wrap pickle.dump() and pickle.load() for you.
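
To make that concrete: a state_dict is an ordinary picklable object, so in principle you could round-trip it with pickle directly. A sketch, purely for illustration (torch.save/torch.load remain the preferred API, since they handle tensor storages and the file format for you):

import pickle
import torch

model = torch.nn.Linear(5, 2)

# Plain pickle round-trip of a state_dict, for illustration only
with open('weights.pkl', 'wb') as f:
    pickle.dump(model.state_dict(), f)

with open('weights.pkl', 'rb') as f:
    state_dict = pickle.load(f)

model.load_state_dict(state_dict)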

The state_dict mentioned in the other answer deserves just a few more notes.

What state_dicts do we have inside PyTorch? There are actually two state_dicts.

The PyTorch model is a torch.nn.Module, which has a model.parameters() call to get the learnable parameters (w and b). These learnable parameters, once randomly set, will update over time as we learn. Learnable parameters are the first state_dict.

The second state_dict is the optimizer state dict. You recall that the optimizer is used to improve our learnable parameters. But the optimizer state_dict is fixed. There is nothing to learn there.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

Let's create a super simple model to explain this:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

This code will output the following:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Note this is a minimal model. You may try adding a stack of sequential layers:

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C),
          torch.nn.Linear(H, D_out),
        )  # D_in, H, A, B, C, D_out are placeholder dimensions

Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm layers) have entries in the model's state_dict.
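
For example, a BatchNorm layer contributes its running statistics to the state_dict as registered buffers, even though they are not learnable parameters; a quick check:

import torch

bn = torch.nn.BatchNorm1d(3)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']  -- the running stats are buffers, not parameters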

Non-learnable things belong to the optimizer object's state_dict, which contains information about the optimizer's state as well as the hyperparameters used.

The rest of the story is the same; in the inference phase (the phase when we use the model after training) we predict based on the parameters we learned. So for inference, we just need to save the parameters model.state_dict().

torch.save(model.state_dict(), filepath)

And to use it later: model.load_state_dict(torch.load(filepath)) followed by model.eval().

Note: Don't forget the last line, model.eval(); this is crucial after loading the model.
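
To see why it matters, here is a small sketch: in train mode Dropout randomly zeroes activations (and rescales the rest), while in eval mode it is a no-op, so predictions are only deterministic after model.eval():

import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))  # roughly half the entries zeroed, the rest scaled by 2

drop.eval()
print(drop(x))  # identical to x: dropout is disabled in eval mode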

Also, don't try to save with torch.save(model.parameters(), filepath). model.parameters() is just a generator object.
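
A quick way to see this for yourself:

import torch

model = torch.nn.Linear(5, 2)
print(type(model.parameters()))       # <class 'generator'>
print(len(list(model.parameters())))  # 2 tensors (weight and bias)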

On the other hand, torch.save(model, filepath) saves the model object itself, but keep in mind the model doesn't hold the optimizer's state_dict. Check the other excellent answer by @Jadiel de Armas to save the optimizer's state dict.

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Save/Load Entire Model

Save:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Load:

(The model class must be defined somewhere.)

model = torch.load(path)
model.eval()

If you want to save the model and resume the training later:

Single GPU: Save:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath = 'checkpoint.t7'
torch.save(state, savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Multiple GPU: Save:

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath = 'checkpoint.t7'
torch.save(state, savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

# Don't wrap the model in DataParallel before loading the state dict,
# otherwise the key names won't match and you will get an error.

model = nn.DataParallel(model)  # skip this line if you want to load on a single GPU
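
If you already have a checkpoint that was saved from a DataParallel model without going through model.module, its keys carry a 'module.' prefix; a sketch for stripping that prefix before loading into a plain single-GPU model:

checkpoint = torch.load('checkpoint.t7')
state_dict = checkpoint['state_dict']
# Strip the 'module.' prefix that nn.DataParallel adds to every key
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)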

Saving locally

How you save your model depends on how you want to access it in the future. If you can call a new instance of the model class, then all you need to do is save/load the weights of the model with model.state_dict():

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

If you cannot for whatever reason (or prefer the simpler syntax), then you can save the entire model (actually a reference to the file(s) defining the model, along with its state_dict) with torch.save():

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

But since this is a reference to the location of the files defining the model class, this code is not portable unless those files are also ported in the same directory structure.

Saving to cloud - TorchHub

If you wish your model to be portable, you can easily allow it to be imported with torch.hub. If you add an appropriately defined hubconf.py file to a GitHub repo, this can be easily called from within PyTorch to enable users to load your model with or without weights:

hubconf.py (github.com/repo_owner/repo_name)

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

Loading the model:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)

These days everything is written in the official tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html

You have several options on how to save and what to save, and it is all explained in that tutorial.

pip install pytorch-lightning

Make sure your parent model uses pl.LightningModule instead of nn.Module.
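
A minimal sketch of such a module, with a toy linear architecture standing in for a real model (the layer sizes, loss, and optimizer below are only placeholders):

import torch
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):
    def __init__(self, hparams=None):
        super().__init__()
        self.layer = torch.nn.Linear(5, 2)  # toy architecture for illustration

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)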

Saving and loading checkpoints using PyTorch Lightning

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer = pl.Trainer()  # a Trainer must be created before calling fit
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")

I always prefer to use Torch7 (.t7) or pickle (.pth, .pt) to save the weights of a PyTorch model.


 