
How does a PyTorch dataset object know that it has hit the end when used in a for loop?

I am writing a custom PyTorch dataset. In __init__ the dataset object loads a file that contains certain data, but in my program I only wish to access part of that data (to implement a train/valid split, for example). Originally I thought this behavior was controlled by overriding __len__ , but it turned out that modifying __len__ does not help. A simple example is as follows:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)

The output is 0 through 9, while the desired behavior would be 0 through 4.

How does this dataset object know that the enumeration has hit the end when used in a for loop like this? Any other method to achieve a similar effect is also welcome.

You are creating a custom dataset with the Dataset class, but you are enumerating it directly with a for loop, which is not how it is meant to be used. For enumeration you should pass the Dataset to the DataLoader class. Your code will work like this:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i in range(len(ds)): # if you don't want to use DataLoader, index the dataset manually
    print(i, ds[i])
#output
0 tensor([-0.2351,  1.3037])
1 tensor([ 0.4032, -0.2739])
2 tensor([-0.5687, -0.7300])
3 tensor([0.5418, 0.8572])
4 tensor([ 1.9973, -0.2939])

dl = DataLoader(ds, batch_size=1) # pass the ds object to DataLoader 

for i, x in enumerate(dl): # now you can use enumerate
    print(i, x)
#output (each x is a batch of size 1, hence the extra dimension)
0 tensor([[-0.2351,  1.3037]])
1 tensor([[ 0.4032, -0.2739]])
2 tensor([[-0.5687, -0.7300]])
3 tensor([[0.5418, 0.8572]])
4 tensor([[ 1.9973, -0.2939]])

Further details can be found in the official PyTorch data loading tutorial.
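To see why enumerate(ds) runs past __len__ in the first place: when an object defines __getitem__ but no __iter__, Python falls back to its legacy sequence-iteration protocol, calling __getitem__ with 0, 1, 2, ... and stopping only when IndexError is raised; __len__ is never consulted. A minimal torch-free sketch (the ToyDS class is hypothetical, just to demonstrate the mechanism):

```python
class ToyDS:
    """Mimics a map-style dataset: __getitem__ + __len__, but no __iter__."""
    def __init__(self):
        self.data = list(range(10))  # 10 underlying items

    def __len__(self):
        return 5  # claim only 5 items

    def __getitem__(self, index):
        return self.data[index]  # raises IndexError past item 9

ds = ToyDS()
seen = [x for x in ds]  # legacy protocol: __getitem__(0), __getitem__(1), ...
print(len(ds))    # 5  -- what __len__ claims
print(len(seen))  # 10 -- iteration only stopped at the IndexError from index 10
```

This is exactly the behavior in the question: the for loop is driven by IndexError from __getitem__, not by __len__. DataLoader, by contrast, does consult len(ds) to decide how many indices to draw.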

You can use torch.utils.data.Subset to get a subset of your data:

top_five = torch.utils.data.Subset(ds, indices=range(5))  # Get first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4
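Subset is essentially a thin wrapper that remaps indices into the wrapped dataset. A pure-Python sketch of the idea (the MySubset class is hypothetical, written without torch to stay self-contained) also shows how it gives you the train/valid cut from the question:

```python
class MySubset:
    """Minimal sketch of what torch.utils.data.Subset does:
    remap indices, then delegate to the wrapped dataset."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]  # remap, then delegate

data = list(range(10))                # stand-in for the loaded data file
train = MySubset(data, range(5))      # first five items
valid = MySubset(data, range(5, 10))  # remaining items

print(list(train))  # [0, 1, 2, 3, 4]
print(list(valid))  # [5, 6, 7, 8, 9]
```

Because __getitem__ raises IndexError as soon as the index list is exhausted, even a plain for loop over the subset stops at the right place.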

enumerate keeps returning items until the underlying iterator raises a StopIteration exception.

len(ds)  # returns the modified length
5

# `iter` builds an iterator over the dataset; each loop step calls `next` on it.
# When no more data is available, a StopIteration exception is raised.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))

print(next(iter_ds))  # 11th call: StopIteration is raised, since no items are left

Output:

tensor([-1.5952, -0.0826])
tensor([-2.2254,  0.2461])
tensor([-0.8268,  0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
     10 print(next(iter_ds))
     11 
---> 12 print(next(iter_ds))  #11th time StopIteration exception raised as no item left to iterate

StopIteration: 
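Since the question also asks for any other method: one option is to define __iter__ yourself, so that a plain for loop honors __len__. This is a sketch of the idea (a hypothetical BoundedDS class, again written without torch; the same __iter__ works on a Dataset subclass):

```python
class BoundedDS:
    """Sketch: add an explicit __iter__ so plain iteration
    stops at __len__ instead of running until IndexError."""
    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data) - 5  # expose only the first 5 items

    def __getitem__(self, index):
        return self.data[index]

    def __iter__(self):
        # yield exactly len(self) items, then stop
        for i in range(len(self)):
            yield self[i]

ds = BoundedDS()
print(list(ds))  # [0, 1, 2, 3, 4] -- iteration now respects __len__
```

When __iter__ is present, Python uses it instead of the legacy __getitem__ protocol, so the loop ends after len(self) items.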
