How to use a PyTorch DataLoader for a dataset with multiple labels

I'm wondering how to create a DataLoader in PyTorch that supports multiple types of labels. How do I do this?

You can return a dict of labels for each item in the dataset, and the DataLoader's default collate function will batch them for you. That is, if each item provides a dict, the DataLoader returns a single dict whose keys are the label types. Indexing that dict with a key gives you a collated (stacked) tensor for that label type.
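The batching described above is performed by PyTorch's default collate function, which `DataLoader` uses when no `collate_fn` is passed. As a minimal sketch, assuming `default_collate` is importable from `torch.utils.data.dataloader` (true in current PyTorch releases), you can call it directly on a hand-built list of samples to watch a list of (input, label-dict) pairs turn into one batched input plus one dict of batched labels. The sample values are made up for illustration.

```python
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate

# Hypothetical samples shaped like M.__getitem__'s output: (input, dict of labels)
samples = [
    (np.array([0.1, 0.2]), {'label_1': np.array([0.1, 0.2]), 'label_2': 0}),
    (np.array([0.3, 0.4]), {'label_1': np.array([0.3, 0.4]), 'label_2': 1}),
    (np.array([0.5, 0.6]), {'label_1': np.array([0.5, 0.6]), 'label_2': 2}),
]

# default_collate transposes the batch, then collates each field recursively:
# NumPy arrays are converted to tensors and stacked, and a dict is collated key by key.
x, y = default_collate(samples)

print(type(x), x.shape)    # batched inputs
print(sorted(y.keys()))    # the label types survive as dict keys
print(y['label_1'].shape)  # array labels stacked along a new batch dimension
print(y['label_2'])        # Python ints are collected into an int64 tensor
```

Each key is collated independently, so label types with different dtypes or shapes can sit side by side in the same dict.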

See below:

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

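class M(Dataset):
    def __init__(self):
        super().__init__()
        # 10 random samples with 2 features each
        self.data = np.random.randn(10, 2)
        print(self.data)

    def __getitem__(self, i):
        # Return (input, labels), where labels is a dict keyed by label type
        return self.data[i], {'label_1': self.data[i], 'label_2': self.data[i]}

    def __len__(self):
        return len(self.data)

ds = M()
dl = DataLoader(ds, batch_size=3)

for x, y in dl:
    # x: batched input tensor; y: dict mapping each label type to a batched tensor
    print(x, '\n', y)
    print(type(x), type(y))

Output (truncated):
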
[[-0.33029911  0.36632142]
 [-0.25303721 -0.11872778]
 [-0.35955625 -1.41633132]
 [ 1.28814629  0.38238357]
 [ 0.72908184 -0.09222787]
 [-0.01777293 -1.81824167]
 [-0.85346074 -1.0319562 ]
 [-0.4144832   0.12125039]
 [-1.29546792 -1.56314292]
 [ 1.22566887 -0.71523568]]
tensor([[-0.3303,  0.3663],
        [-0.2530, -0.1187],
        [-0.3596, -1.4163]], dtype=torch.float64) 
 {'label_1': tensor([[-0.3303,  0.3663],
        [-0.2530, -0.1187],
        [-0.3596, -1.4163]], dtype=torch.float64), 'label_2': tensor([[-0.3303,  0.3663],
        [-0.2530, -0.1187],
        [-0.3596, -1.4163]], dtype=torch.float64)}
<class 'torch.Tensor'> <class 'dict'>
...
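Because the labels come back as a dict of batched tensors, each label type can be routed to its own loss. The following is a sketch, not part of the original answer: `MultiLabelDataset` and `TwoHeadModel` are hypothetical names, and pairing a class index with a float regression target (trained with cross-entropy and MSE respectively) is an assumption chosen to show two genuinely different label types collated together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class MultiLabelDataset(Dataset):
    """Hypothetical dataset: each item has a class label and a regression target."""
    def __init__(self, n=10):
        super().__init__()
        g = torch.Generator().manual_seed(0)
        self.x = torch.randn(n, 2, generator=g)
        self.cls = torch.randint(0, 3, (n,), generator=g)  # class index per item (int64)
        self.reg = torch.randn(n, generator=g)             # float target per item

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        # Same pattern as the answer: (input, dict of labels)
        return self.x[i], {'class': self.cls[i], 'target': self.reg[i]}

class TwoHeadModel(nn.Module):
    """Hypothetical model with one output head per label type."""
    def __init__(self):
        super().__init__()
        self.cls_head = nn.Linear(2, 3)
        self.reg_head = nn.Linear(2, 1)

    def forward(self, x):
        return self.cls_head(x), self.reg_head(x).squeeze(-1)

dl = DataLoader(MultiLabelDataset(), batch_size=3)
model = TwoHeadModel()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for x, y in dl:
    logits, pred = model(x)
    # Look up each label type by key and give it its own loss
    loss = F.cross_entropy(logits, y['class']) + F.mse_loss(pred, y['target'])
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

With 10 items and `batch_size=3`, the loop sees batches of 3, 3, 3 and 1; the last, smaller batch is collated the same way.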
