How does a pytorch dataset object know whether it has hit the end when used in a for loop?

I am writing a custom pytorch dataset. In __init__ the dataset object loads a file that contains certain data, but in my program I only want to access part of that data (to implement a train/validation split, if that helps). Originally I thought this behavior was controlled by overriding __len__ , but it turned out that modifying __len__ does not help. A simple example is as follows:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)

The output is 0 through 9, while the desired behavior would be 0 through 4.

How does this dataset object know that the enumeration has hit the end when used in a for loop like this? Any other method to achieve a similar effect is also welcome.

You are creating a custom dataset with the Dataset class, but then enumerating it directly in a for loop. That is not how it is meant to be used. For enumeration you should pass the Dataset to the DataLoader class. Your code will work like this:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i in range(len(ds)): # if you don't want to use DataLoader, index the dataset directly instead of enumerating it
    print(i, ds[i])
#output 
0 tensor([-0.2351,  1.3037])
1 tensor([ 0.4032, -0.2739])
2 tensor([-0.5687, -0.7300])
3 tensor([0.5418, 0.8572])
4 tensor([ 1.9973, -0.2939])

dl = DataLoader(ds, batch_size=1) # pass the ds object to DataLoader 

for i, x in enumerate(dl): # now you can use enumerate
    print(i, x)
#output
0 tensor([[-0.2351,  1.3037]])
1 tensor([[ 0.4032, -0.2739]])
2 tensor([[-0.5687, -0.7300]])
3 tensor([[0.5418, 0.8572]])
4 tensor([[ 1.9973, -0.2939]])

Further details can be read in this official pytorch tutorial.
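As a side note, the reason DataLoader honors the modified __len__ is that, with shuffle=False, it draws indices from a SequentialSampler, which simply iterates over range(len(dataset)). A quick check (a minimal sketch, reusing the ds object defined above):

from torch.utils.data import SequentialSampler

print(list(SequentialSampler(ds)))  # [0, 1, 2, 3, 4] -- only the 5 indices reported by __len__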

You can use torch.utils.data.Subset to get a subset of your data:

top_five = torch.utils.data.Subset(ds, indices=range(5))  # Get first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4
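
If the actual goal is a train/validation split rather than literally the first five items, torch.utils.data.random_split builds two Subset objects in one call. A minimal sketch, assuming the ds defined in the question (whose __len__ reports 5 items):

from torch.utils.data import random_split

train_ds, valid_ds = random_split(ds, [3, 2])  # randomly split the 5 reported items into 3 + 2
print(len(train_ds), len(valid_ds))            # 3 2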

A for loop (or enumerate) keeps returning items until it hits a StopIteration exception.

len(ds)         # returns the modified length
5

# `enumerate` calls the `next` method on the iterable each time through the loop.
# When no more data is available, a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))

print(next(iter_ds))  # 11th call: StopIteration is raised, as there are no items left to iterate

Output:

tensor([-1.5952, -0.0826])
tensor([-2.2254,  0.2461])
tensor([-0.8268,  0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])

---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
     10 print(next(iter_ds))
     11 
---> 12 print(next(iter_ds))  #11th time StopIteration exception raised as no item left to iterate

StopIteration: 
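To tie this back to the original question: a plain for loop over ds never consults __len__ at all. Since NewDS defines __getitem__ but no __iter__, Python falls back to its sequence-iteration protocol, calling ds[0], ds[1], ... until __getitem__ raises IndexError, which the iterator converts into StopIteration. The underlying tensor has 10 rows, so indexing only fails at index 10, which is why the loop prints 0 through 9. A minimal sketch (BoundedDS is a hypothetical name) of how to make the bare loop stop after 5 items is to raise IndexError yourself:

class BoundedDS(NewDS):
    def __getitem__(self, index):
        if index >= len(self):        # len(self) is 5 here
            raise IndexError(index)   # signals the fallback iteration protocol to stop
        return self.data[index]

for i, x in enumerate(BoundedDS()):
    print(i)  # prints 0 through 4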
