[英]AttributeError: Can't pickle local object 'pre_datasets.<locals>.<lambda>' when implementing Pytorch framework
[英]PyTorch can't pickle lambda
我有一個使用自定義LambdaLayer
的 model,如下所示:
class LambdaLayer(LightningModule):
def __init__(self, fun):
super(LambdaLayer, self).__init__()
self.fun = fun
def forward(self, x):
return self.fun(x)
class TorchCatEmbedding(LightningModule):
def __init__(self, start, end):
super(TorchCatEmbedding, self).__init__()
self.lb = LambdaLayer(lambda x: x[:, start:end])
self.embedding = torch.nn.Embedding(50, 5)
def forward(self, inputs):
o = self.lb(inputs).to(torch.int32)
o = self.embedding(o)
return o.squeeze()
model 在 CPU 或 1 GPU 上運行完美。 但是,當使用 PyTorch Lightning 超過 2 個 GPU 運行它時,會發生此錯誤:
AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'
在這里使用 lambda function 的目的是給定inputs
張量,我只想將inputs[:, start:end]
傳遞給embedding
層。
我的問題:
所以問題不在於 lambda function 本身,而是 pickle 不適用於不僅僅是模塊級函數的函數(pickle 處理函數的方式就像對某些模塊級名稱的引用)。 因此,不幸的是,如果您需要捕獲 arguments 的start
和end
,您將無法使用閉包,您通常只需要以下內容:
def function_maker(start, end):
def function(x):
return x[:, start:end]
return function
但這會讓你回到你開始的地方,就酸洗問題而言。
因此,請嘗試以下操作:
class Slicer:
def __init__(self, start, end):
self.start = start
self.end = end
def __call__(self, x):
return x[:, self.start:self.end])
然后你可以使用:
LambdaLayer(Slicer(start, end))
我不熟悉 PyTorch,雖然它不提供使用不同序列化后端的能力,但我很驚訝。 例如, pathos/ dill項目可以腌制任意函數,並且通常更容易使用它。 但我相信以上應該可以解決問題。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.