簡體   English   中英

如何在 PyTorch 中保存經過訓練的 model?

[英]How do I save a trained model in PyTorch?

如何在 PyTorch 中保存經過訓練的 model? 我讀過:

  1. torch.save() / torch.load()用於保存/加載可序列化的 object。
  2. model.state_dict() /model.load_state_dict()用於保存/加載 model Z96739E2A693E258。

我在他們的 github repo 上找到了這個頁面,我會把內容貼在這里。


保存模型的推薦方法

序列化和恢復模型有兩種主要方法。

第一個(推薦)只保存和加載模型參數:

torch.save(the_model.state_dict(), PATH)

后來:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二個保存並加載整個模型:

torch.save(the_model, PATH)

后來:

the_model = torch.load(PATH)

但是在這種情況下,序列化數據綁定到特定的類和所使用的確切目錄結構,因此在其他項目中使用時,或者經過一些嚴重的重構后,它可能會以各種方式中斷。

這取決於你想做什么。

案例#1:保存模型以供自己用於推理:您保存模型,恢復它,然后將模型更改為評估模式。 這樣做是因為您通常有BatchNormDropout層,它們在構建時默認處於訓練模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便稍后繼續訓練:如果您需要繼續訓練您將要保存的模型,您需要保存的不僅僅是模型。 您還需要保存優化器的狀態、時期、分數等。您可以這樣做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢復訓練,您可以執行以下操作: state = torch.load(filepath) ,然后恢復每個單獨對象的狀態,如下所示:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由於您正在恢復訓練,因此在加載時恢復狀態后不要調用model.eval()

案例#3:模型供其他人使用而無法訪問您的代碼:在 Tensorflow 中,您可以創建一個.pb文件來定義模型的架構和權重。 這非常方便,特別是在使用Tensorflow serve 在 Pytorch 中執行此操作的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

這種方式仍然不是防彈的,而且由於 pytorch 仍在經歷很多變化,我不推薦它。

pickle Python 庫實現了用於序列化和反序列化 Python 對象的二進制協議。

當你import torch (或者當你使用PyTorch)時,它會為你import pickle ,你不需要直接調用pickle.dump()pickle.load() ,它們是保存和加載對象的方法。

事實上, torch.save()torch.load()會為你包裝pickle.dump()pickle.load()

提到的另一個答案的state_dict只需要多加注意。

什么state_dict我們有內部PyTorch? 實際上有兩個state_dict

PyTorch 模型是torch.nn.Module ,它具有model.parameters()調用以獲取可學習參數(w 和 b)。 這些可學習的參數一旦隨機設置,就會隨着我們的學習而隨着時間的推移而更新。 可學習的參數是第一個state_dict

第二個state_dict是優化器狀態字典。 您還記得優化器用於改進我們的可學習參數。 但是優化器state_dict是固定的。 沒有什么可在那里學習的。

由於state_dict對象是 Python 字典,因此它們可以輕松保存、更新、更改和恢復,從而為 PyTorch 模型和優化器添加了大量模塊化。

讓我們創建一個超級簡單的模型來解釋這一點:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

此代碼將輸出以下內容:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

請注意,這是一個最小模型。 您可以嘗試添加順序堆棧

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

請注意,只有具有可學習參數的層(卷積層、線性層等)和注冊緩沖區(batchnorm 層)在模型的state_dict有條目。

不可學習的東西屬於優化器對象state_dict ,它包含有關優化器狀態的信息,以及使用的超參數。

故事的其余部分是相同的; 在推理階段(這是我們訓練后使用模型的階段)進行預測; 我們確實根據我們學到的參數進行預測。 所以對於推理,我們只需要保存參數model.state_dict()

torch.save(model.state_dict(), filepath)

並在以后使用 model.load_state_dict(torch.load(filepath)) model.eval()

注意:不要忘記最后一行model.eval()這在加載模型后至關重要。

也不要嘗試保存torch.save(model.parameters(), filepath) model.parameters()只是生成器對象。

另一方面, torch.save(model, filepath)保存模型對象本身,但請記住,模型沒有優化器的state_dict 檢查@Jadiel de Armas 的另一個優秀答案以保存優化器的狀態字典。

一個常見的 PyTorch 約定是使用 .pt 或 .pth 文件擴展名保存模型。

保存/加載整個模型

節省:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

加載:

(模型類必須在某處定義)

model.load_state_dict(torch.load(PATH))
model.eval()

如果您想保存模型並想稍后繼續訓練:

單 GPU:保存:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

加載:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

多 GPU:保存

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

加載:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU

本地保存

您保存模型的方式取決於您以后希望如何訪問它。 如果您可以調用model類的新實例,那么您需要做的就是使用model.state_dict()保存/加載模型的權重:

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

如果你不能因為任何原因(或者更喜歡更簡單的語法),那么你可以用torch.save()保存整個模型(實際上是對定義模型的文件的引用,以及它的 state_dict torch.save()

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

但由於這是對定義模型類的文件位置的引用,除非這些文件也被移植到相同的目錄結構中,否則此代碼不可移植。

保存到雲端 - TorchHub

如果您希望您的模型便於攜帶,您可以使用torch.hub輕松地將其導入。 如果您將適當定義的hubconf.py文件添加到 github 存儲庫,則可以從 PyTorch 中輕松調用該文件,以使用戶能夠使用/不使用權重加載您的模型:

hubconf.py ( github.com/repo_owner/repo_name )

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

加載模型:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)

這幾天一切都寫在官方教程中: https : //pytorch.org/tutorials/beginner/ Saving_loading_models.html

您有幾個關於如何保存和保存內容的選項,所有內容都在該教程中進行了解釋。

pip 安裝 pytorch-lightning

確保您的父模型使用 pl.LightningModule 而不是 nn.Module

使用 pytorch 閃電保存和加載檢查點

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

我總是喜歡使用 Torch7 (.t7) 或 Pickle (.pth, .pt) 來保存 pytorch 模型的權重。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM