
How to get a specific sample from a PyTorch DataLoader?

In PyTorch, is there any way of loading a specific single sample using the torch.utils.data.DataLoader class? I'd like to do some testing with it.

The tutorial uses

trainloader = torch.utils.data.DataLoader(...)
images, labels = next(iter(trainloader))

to fetch a random batch of samples. Is there a way, using DataLoader, to get a specific sample?

Cheers

  • Turn off shuffling in the DataLoader (shuffle=False)
  • Use batch_size to work out which batch the desired sample falls in
  • Iterate to that batch

Code

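import itertools

import numpy as np
import torch

X = np.arange(100)
batch_size = 2

# shuffle=False keeps the samples in their original order
dataloader = torch.utils.data.DataLoader(X, batch_size=batch_size, shuffle=False)
sample_at = 5
k = sample_at // batch_size  # index of the batch that contains sample_at

# skip the first k batches and take the next one (this is the whole batch)
my_sample = next(itertools.islice(dataloader, k, None))
print(my_sample)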

Output:

tensor([4, 5])
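The code above returns the whole batch that contains sample 5, not the sample itself. Below is a minimal sketch of pulling out the single sample, assuming shuffle=False, the default sequential sampler, and a loader built with batch_size (get_sample is a hypothetical helper name, not a PyTorch API). The position inside the batch is index % batch_size. Since a DataLoader keeps the wrapped dataset in its dataset attribute, indexing that directly is also an option when you don't need batching or collation.

```python
import itertools

import numpy as np
from torch.utils.data import DataLoader


def get_sample(loader, index):
    """Return the single sample at `index` from an unshuffled DataLoader.

    Assumes shuffle=False and the default sampler, so sample `index`
    lives in batch `index // batch_size` at offset `index % batch_size`.
    """
    batch_idx, offset = divmod(index, loader.batch_size)
    # islice skips (but still loads) every batch before the target one
    batch = next(itertools.islice(loader, batch_idx, None))
    return batch[offset]


X = np.arange(100)
loader = DataLoader(X, batch_size=2, shuffle=False)

print(get_sample(loader, 5))  # tensor(5)
# The wrapped dataset can also be indexed directly, bypassing batching:
print(loader.dataset[5])  # 5
```

Because islice has to load every earlier batch, indexing loader.dataset is much cheaper for samples near the end of a large dataset.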

If you want to get a specific single sample from your dataset, you should check out the Subset class (https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset). Something like this:

from torch.utils.data import DataLoader, Subset

indices = [0, 1, 2]  # select your indices here as a list
subset = Subset(train_set, indices)  # train_set is your existing Dataset
trainloader = DataLoader(subset, batch_size=16, shuffle=False)  # set shuffle to False

for image, label in trainloader:
    print(image.size(), '\t', label.size())
    print(image[0], '\t', label[0])  # index the specific sample
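Since train_set is not defined in that snippet, here is a self-contained sketch that uses a hypothetical toy TensorDataset in its place (10 random 3x4x4 "images", with label i for sample i). Wrapping a single index in Subset makes the loader yield exactly that sample.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical toy dataset standing in for train_set:
# 10 "images" of shape 3x4x4, with label i for sample i.
images = torch.randn(10, 3, 4, 4)
labels = torch.arange(10)
train_set = TensorDataset(images, labels)

# Keep only index 7; the loader then yields just that one sample.
subset = Subset(train_set, [7])
loader = DataLoader(subset, batch_size=1, shuffle=False)

image, label = next(iter(loader))
print(image.size(), label.size())  # torch.Size([1, 3, 4, 4]) torch.Size([1])
print(label.item())  # 7
```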

Here is a useful link if you want to learn more about the PyTorch data loading utilities: https://pytorch.org/docs/stable/data.html
