簡體   English   中英

PyTorch 不能腌制 lambda

[英]PyTorch can't pickle lambda

我有一個使用自定義LambdaLayer的 model,如下所示:

class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)


class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()

model 在 CPU 或 1 GPU 上運行完美。 但是,當使用 PyTorch Lightning 超過 2 個 GPU 運行它時,會發生此錯誤:

AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'

在這里使用 lambda function 的目的是給定inputs張量,我只想將inputs[:, start:end]傳遞給embedding層。

我的問題:

  • 在這種情況下,是否有替代使用 lambda 的方法?
  • 如果沒有,應該怎么做才能讓 lambda function 在這種情況下工作?

所以問題不在於 lambda function 本身,而是 pickle 不適用於不僅僅是模塊級函數的函數(pickle 處理函數的方式就像對某些模塊級名稱的引用)。 因此,不幸的是,如果您需要捕獲 arguments 的startend ,您將無法使用閉包,您通常只需要以下內容:

def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function

但這會讓你回到你開始的地方,就酸洗問題而言。

因此,請嘗試以下操作:

class Slicer:
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __call__(self, x):
        return x[:, self.start:self.end])

然后你可以使用:

LambdaLayer(Slicer(start, end))

我不熟悉 PyTorch,雖然它不提供使用不同序列化后端的能力,但我很驚訝。 例如, pathos/ dill項目可以腌制任意函數,並且通常更容易使用它。 但我相信以上應該可以解決問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM