How does a pytorch dataset object know whether it has hit the end when used in a for loop?
I am writing a custom pytorch dataset. In __init__ the dataset object loads a file that contains certain data, but in my program I only wish to access part of that data (to achieve a train/validation split, if that helps). Originally I thought this behavior was controlled by overriding __len__, but it turned out that modifying __len__ does not help. A simple example is as follows:
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)
The output is 0 through 9, while the desired behavior would be 0 through 4.
How does this dataset object know that the enumeration has hit the end when used in a for loop like this? Any other method to achieve a similar effect is also welcome.
You are creating a custom dataset with the Dataset class, but you are enumerating it directly with a for loop. That is not how it is meant to be used. For enumeration you have to pass the Dataset to the DataLoader class. Your code will work like this:
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i in range(len(ds)):  # if you don't want to use DataLoader, index manually instead of enumerating
    print(i, ds[i])
# output
0 tensor([-0.2351,  1.3037])
1 tensor([ 0.4032, -0.2739])
2 tensor([-0.5687, -0.7300])
3 tensor([0.5418, 0.8572])
4 tensor([ 1.9973, -0.2939])
dl = DataLoader(ds, batch_size=1)  # pass the ds object to DataLoader
for i, x in enumerate(dl):  # now you can use enumerate
    print(i, x)

# output (each item now carries a leading batch dimension of 1)
0 tensor([[-0.2351,  1.3037]])
1 tensor([[ 0.4032, -0.2739]])
2 tensor([[-0.5687, -0.7300]])
3 tensor([[0.5418, 0.8572]])
4 tensor([[ 1.9973, -0.2939]])
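Why does DataLoader stop after five items when the plain for loop did not? Its default sampler draws indices from range(len(dataset)), so the modified __len__ is exactly what takes effect here. A minimal, self-contained sketch of that check (re-stating the NewDS class from above):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # 10 items in the underlying data

    def __len__(self):
        return len(self.data) - 5  # expose only the first 5

    def __getitem__(self, index):
        return self.data[index]

dl = DataLoader(NewDS(), batch_size=1)
# The default sampler iterates over range(len(dataset)), so both the
# reported length and the actual iteration honor __len__:
print(len(dl))             # 5
print(sum(1 for _ in dl))  # 5 batches of size 1
```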
Further details can be read in this official pytorch tutorial.
You can use torch.utils.data.Subset to get a subset of your data:

top_five = torch.utils.data.Subset(ds, indices=range(5))  # get the first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4
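Since the question mentions a train/validation cut: torch.utils.data.random_split builds two such Subset objects in one call. A sketch, where the 5/5 split sizes and the generator seed are arbitrary choices for this example:

```python
import torch
from torch.utils.data import Dataset, random_split

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)

    def __len__(self):
        return len(self.data)  # expose everything; the split handles the cut

    def __getitem__(self, index):
        return self.data[index]

# Split 10 items into 5 train / 5 valid; seeding the generator makes it reproducible.
train_ds, valid_ds = random_split(
    NewDS(), [5, 5], generator=torch.Generator().manual_seed(0)
)
print(len(train_ds), len(valid_ds))  # 5 5
```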
enumerate in a loop will return items until it gets a StopIteration exception.
len(ds) # Returned modified length
5
# `enumerate` will call `next` on the iterable each time through the loop,
# and when no more data is available a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))  # 11th call raises StopIteration, as no items are left to iterate
Output:
tensor([-1.5952, -0.0826])
tensor([-2.2254, 0.2461])
tensor([-0.8268, 0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
10 print(next(iter_ds))
11
---> 12 print(next(iter_ds)) #11th time StopIteration exception raised as no item left to iterate
StopIteration:
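The underlying mechanism is plain Python, not pytorch: a class that defines __getitem__ but no __iter__ is iterated with the legacy sequence protocol, which calls __getitem__ with 0, 1, 2, … and stops only when IndexError is raised; __len__ is never consulted. A torch-free sketch (FakeDS is just an illustrative name):

```python
class FakeDS:
    def __init__(self):
        self.data = list(range(10))  # 10 items

    def __len__(self):
        return 5  # ignored by plain iteration

    def __getitem__(self, index):
        # list indexing raises IndexError at index 10, which is what
        # terminates the legacy sequence-protocol iteration
        return self.data[index]

ds = FakeDS()
print(len(ds))         # 5
print(list(iter(ds)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] — all 10 items
```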