I am writing a custom PyTorch dataset. In __init__ the dataset object loads a file that contains some data, but in my program I only want to access part of that data (to achieve a train/validation split, if it helps). Originally I thought this behavior was controlled by overriding __len__, but it turned out that modifying __len__ does not help. A simple example is as follows:
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # But I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)
The output is 0 through 9, while the desired behavior would be 0 through 4.
How does this dataset object know that the enumeration has hit the end when used in a for loop like this? Any other method to achieve a similar effect is also welcome.
You are creating a custom Dataset, but you are enumerating the Dataset object itself with a for loop. That is not how its length is enforced: iterating a Dataset directly does not consult __len__ at all. To respect the length, either index it yourself with range(len(ds)) or pass the Dataset to a DataLoader. Your code will work like this:
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # But I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i in range(len(ds)):  # if you don't want to use DataLoader, iterate over range(len(ds))
    print(i, ds[i])
#output
0 tensor([-0.2351,  1.3037])
1 tensor([ 0.4032, -0.2739])
2 tensor([-0.5687, -0.7300])
3 tensor([ 0.5418,  0.8572])
4 tensor([ 1.9973, -0.2939])
dl = DataLoader(ds, batch_size=1)  # pass the ds object to DataLoader
for i, x in enumerate(dl):  # now you can use enumerate
    print(i, x)
#output (with batch_size=1, each item gains a leading batch dimension)
0 tensor([[-0.2351,  1.3037]])
1 tensor([[ 0.4032, -0.2739]])
2 tensor([[-0.5687, -0.7300]])
3 tensor([[ 0.5418,  0.8572]])
4 tensor([[ 1.9973, -0.2939]])
Further details can be found in the official PyTorch data loading tutorial.
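If the underlying goal is a train/validation split, torch.utils.data.random_split is another option that composes directly with DataLoader. A sketch, assuming a 5/5 split of the same 10-item dataset; note that __len__ here reports the true length, and the split is done outside the Dataset class:

```python
import torch
from torch.utils.data import Dataset, DataLoader, random_split

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # 10 items, as in the question

    def __len__(self):
        return len(self.data)  # report the true length here

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
# randomly partition the 10 items into two disjoint subsets of 5
train_ds, valid_ds = random_split(ds, [5, 5])

train_dl = DataLoader(train_ds, batch_size=1)
for i, x in enumerate(train_dl):  # iterates exactly 5 batches
    print(i, x)
```

Each returned subset is itself a Dataset, so either half can be wrapped in its own DataLoader.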
You can use torch.utils.data.Subset to get a subset of your data:
top_five = torch.utils.data.Subset(ds, indices=range(5))  # Get first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4
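The same Subset approach extends to the train/validation cut mentioned in the question. A sketch, assuming the first five items form the training set and the rest the validation set (Subset works anywhere a Dataset does, including DataLoader):

```python
import torch
from torch.utils.data import Dataset, DataLoader, Subset

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # 10 items, as in the question

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
train_ds = Subset(ds, indices=range(5))      # first five items
valid_ds = Subset(ds, indices=range(5, 10))  # remaining five items

# a Subset can be passed straight to a DataLoader
train_dl = DataLoader(train_ds, batch_size=1)
print(len(train_ds), len(valid_ds))  # 5 5
```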
enumerate keeps yielding items until the iterator raises a StopIteration exception. Because NewDS defines __getitem__ but not __iter__, Python falls back to the sequence protocol: it calls __getitem__ with 0, 1, 2, ... until an IndexError is raised, which the iterator converts into StopIteration. __len__ is never consulted, which is why the loop runs to 9 instead of stopping at 4.
len(ds) # Returned modified length
5
# `enumerate` calls `next` on the iterator each time through the loop.
# When no more data is available, a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))  # 11th call raises StopIteration, as no items are left to iterate
Output:
tensor([-1.5952, -0.0826])
tensor([-2.2254, 0.2461])
tensor([-0.8268, 0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
10 print(next(iter_ds))
11
---> 12 print(next(iter_ds)) #11th time StopIteration exception raised as no item left to iterate
StopIteration:
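The same fallback can be demonstrated without torch. A minimal pure-Python sketch (the FiveOfTen class is hypothetical, chosen to mirror the question): a class with __getitem__ but no __iter__ is iterated via the legacy sequence protocol, which calls __getitem__(0), __getitem__(1), ... until IndexError, ignoring __len__ entirely.

```python
class FiveOfTen:
    def __init__(self):
        self.data = list(range(10))  # 10 items

    def __len__(self):
        return 5                      # claims a length of 5

    def __getitem__(self, index):
        return self.data[index]       # raises IndexError at index 10

seq = FiveOfTen()
print(len(seq))    # 5
print(list(seq))   # [0, 1, ..., 9] — all 10 items; __len__ is ignored
```

This is why overriding __len__ alone cannot shorten direct iteration: only an IndexError from __getitem__ (or an explicit __iter__) ends the loop.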