簡體   English   中英

一個pickle任意pytorch型號使用lambda功能如何?

[英]How does one pickle arbitrary pytorch models that use lambda functions?

我目前有一個神經網絡模塊:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self,args,lambda_f,nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Params stuff etc...

    def forward(self, x):
        #some code using fields
        return out

我正在嘗試檢查它,但是因為 pytorch 使用state_dict保存,這意味着我無法保存我實際使用的 lambda 函數,如果我使用torch.save進行檢查點,而不會出現問題,我想從字面上保存一切。稍后在 GPU 上進行訓練。 我目前正在使用這個:

def save_ckpt(path_to_ckpt):
    from pathlib import Path
    import dill as pickle
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / Path('db')

    ## Pickle args
    db['crazy_mdl'] = crazy_mdl
    with open(ckpt_path_plus_path , 'ab') as db_file:
        pickle.dump(db, db_file)

目前,當我檢查它並保存它時,它不會引發任何錯誤。

我擔心當我訓練它時,即使沒有訓練異常/錯誤或者可能發生意外的事情(例如在集群中的磁盤上奇怪的保存等誰知道),也可能會出現一個微妙的錯誤。

這對 pytorch 類/nn 模型安全嗎? 特別是如果我們想恢復使用 GPU 進行訓練?

交叉貼:

我是dill作者。 我使用dill (和klepto )在 lambda 函數中保存包含經過訓練的 ANN 的類。 我傾向於使用mysticsklearn的組合,所以我不能直接與pytorch ,但我可以假設它的工作原理相同。 您必須小心的地方是,如果您有一個 lambda,其中包含指向 lambda 外部的 object 的指針...例如y = 4; f = lambda x: x+y y = 4; f = lambda x: x+y 這似乎很明顯,但是dill會腌制 lambda,並且取決於代碼的 rest 和序列化變體,可能不會序列化y的值。 所以,我見過很多情況,人們在一些 function(或 lambda,或類)中序列化一個訓練有素的估計器,然后當他們從 C 序列化恢復 ZC1C425268E68385D1AB5074 時結果不是“正確的”。 最重要的原因是因為 function 沒有被封裝,所以 function 產生正確結果所需的所有對象都存儲在泡菜中。 但是,即使在這種情況下,您也可以獲得“正確”的結果,但是您只需要創建與腌制估算器時相同的環境(即,它依賴於周圍命名空間中的所有相同值)。 要點應該是,盡量確保 function 中使用的所有變量都在 function 中定義。 這是我最近開始使用自己的 class 的一部分(應該在mystic的下一個版本中):

class Estimator(object):
    "a container for a trained estimator and transform (not a pipeline)"
    def __init__(self, estimator, transform):
        """a container for a trained estimator and transform

    Input:
        estimator: a fitted sklearn estimator
        transform: a fitted sklearn transform
        """
        self.estimator = estimator
        self.transform = transform
        self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
    def __call__(self, *x):
        "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
        import numpy as np
        return self.function(*x)

請注意,當調用 function 時,它使用的所有內容(包括np )都在周圍的命名空間中定義。 只要pytorch估計器按預期序列化(沒有外部引用),那么如果您遵循上述指南,您應該沒問題。

Yes, I think it is safe to use dill to pickle lambda functions etc. I have been using torch.save with dill to save state dict and have had no problems resuming training over GPU as well as CPU unless the model class was changed. Even if the model class was changed (adding/deleting some parameters), I could load state dict, modify it, and load to the model.

Also, usually, people don't save the model objects but only state dicts ie parameter values to resume the training along with hyperparameters/model arguments to get the same model object later.

Saving model object can be sometimes problematic as changes to model class (code) can make the saved object useless. If you don't plan on changing your model class/code at all and hence the model object won't be changed then maybe saving objects can work well but generally, it is not recommended to pickle module object.

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM