簡體   English   中英

Pytorch 張量使用所有 RAM

[英]Pytorch Tensors using all RAM

我有一個張量列表,這對我的 RAM 來說太重了。 我想將它們保存在文件系統中並在需要時加載它們

torch.save(single_tensor, 'tensor_<idx>.pt')

如果我想在訓練時使用批處理,是否有一種在需要時自動加載張量的方法? 我正在考慮使用TensorDatasetDataLoader ,但是現在我在列表中沒有張量,但在文件系統中,我應該如何構建它們?

首先使用torch.save()將張量一一保存到文件中

torch.save(tensor, 'path/to/file.pt')

然后這個Dataset class 允許僅在真正需要時加載張量:

class EmbedDataset(torch.utils.data.Dataset):
    def __init__(self, first_embed_path, second_embed_path, labels):
        self.first_embed_path = first_embed_path 
        self.second_embed_path = second_embed_path 
        self.labels = labels

        

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):

        label = self.labels[i]

        embed = torch.load(os.path.join(self.first_embed_path, str(i) + '.pt'))

        pos = torch.load(os.path.join(self.second_embed_path, str(i) + '.pt'))

        tensor = torch.cat((embed, pos))

        return tensor, label

這里張量用數字命名,例如1.pt1816.pt

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM