PyTorch tensors using all RAM

I have a list of tensors that is too large to fit in my RAM. I would like to save them to the filesystem and load them only when needed:

torch.save(single_tensor, 'tensor_<idx>.pt')

If I want to train in batches, is there a way to load the tensors automatically when they are needed? I was thinking about using TensorDataset and DataLoader, but since the tensors are now stored on the filesystem rather than in a list, how should I build them?

First, save the tensors one by one to separate files with torch.save():

torch.save(tensor, 'path/to/file.pt')
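A minimal sketch of this saving step, assuming the tensors start out in a Python list and that each file is named after its list index (the `<index>.pt` naming the Dataset class below reads from). The helper name `save_tensors` and the temporary output directory are hypothetical.

```python
import os
import tempfile

import torch


def save_tensors(tensors, out_dir):
    """Hypothetical helper: write each tensor to <out_dir>/<index>.pt."""
    os.makedirs(out_dir, exist_ok=True)
    for idx, tensor in enumerate(tensors):
        torch.save(tensor, os.path.join(out_dir, f"{idx}.pt"))


if __name__ == "__main__":
    # Small random tensors stand in for the real data
    tensors = [torch.randn(4) for _ in range(3)]
    out_dir = tempfile.mkdtemp()
    save_tensors(tensors, out_dir)
    print(sorted(os.listdir(out_dir)))  # ['0.pt', '1.pt', '2.pt']
    # Once every tensor is on disk, the in-memory list can be dropped to free RAM
    del tensors
```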

Then the following Dataset class loads each tensor only when it is actually needed. In this example every sample is built from two tensors stored in two separate directories:

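import os

import torch


class EmbedDataset(torch.utils.data.Dataset):
    """Loads the two saved embeddings of each sample lazily from disk."""

    def __init__(self, first_embed_path, second_embed_path, labels):
        # Only the directory paths and labels are kept in memory, not the tensors
        self.first_embed_path = first_embed_path
        self.second_embed_path = second_embed_path
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        label = self.labels[i]

        # Each tensor is read from disk only when this sample is requested
        embed = torch.load(os.path.join(self.first_embed_path, str(i) + '.pt'))
        pos = torch.load(os.path.join(self.second_embed_path, str(i) + '.pt'))

        # Concatenate the two tensors along dimension 0
        tensor = torch.cat((embed, pos))

        return tensor, label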

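Here the tensor files are named after their sample index, e.g. 1.pt or 1816.pt, so file i.pt in each directory must correspond to labels[i] (the index i runs from 0 to len(labels) - 1).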
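For the simpler case in the question (one tensor per sample, saved as `tensor_<idx>.pt`), the same idea can be sketched with a single directory and plugged into a standard `DataLoader` for batching. `LazyTensorDataset`, the directory layout and the integer labels are hypothetical stand-ins. Note that the default collate function stacks the samples of a batch, so every saved tensor must have the same shape.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset


class LazyTensorDataset(Dataset):
    """Hypothetical dataset: sample i is read from <root>/tensor_<i>.pt on demand."""

    def __init__(self, root, labels):
        self.root = root      # directory containing tensor_0.pt, tensor_1.pt, ...
        self.labels = labels  # one label per saved tensor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        # Only this one file is loaded into memory for this sample
        path = os.path.join(self.root, f"tensor_{i}.pt")
        return torch.load(path), self.labels[i]


if __name__ == "__main__":
    root = tempfile.mkdtemp()
    for idx in range(10):
        # Stand-in data: ten tensors of shape (8,)
        torch.save(torch.randn(8), os.path.join(root, f"tensor_{idx}.pt"))
    labels = list(range(10))

    dataset = LazyTensorDataset(root, labels)
    # DataLoader calls __getitem__ for each index and collates the results into batches
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for x, y in loader:
        # x has shape (4, 8), except the last batch, which holds the remaining 2 samples
        print(x.shape, y)
```

Because each `__getitem__` call reads a single file, only the current batch (plus any batches prefetched by worker processes when `num_workers > 0`) needs to be held in memory; worker processes also let the file reads overlap with training.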
