How does a PyTorch dataset object know whether it has hit the end when used in a for loop?
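I am writing a custom PyTorch dataset. In __init__, the dataset object loads a file containing certain data. In my program, however, I only want to access part of that data (to implement a train/validation split, if that helps). I originally thought this behavior was controlled by overriding __len__, but it turns out that modifying __len__ does not help. A simple example follows: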
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)
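The output is 0 to 9, whereas the desired behavior is 0 to 4.
How does this dataset object know that the enumeration has ended when it is used in a for loop like this? Any other way of achieving a similar effect is also welcome.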
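You are creating a custom data loader with the Dataset class and then enumerating it directly with a for loop. That is not how it is meant to work. To enumerate it, pass the Dataset to the DataLoader class. Your code works fine like this: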
from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
# if you don't want to use DataLoader, index by position instead of using enumerate
for i in range(len(ds)):
    print(i, ds[i])
#output
tensor([-0.2351, 1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])
dl = DataLoader(ds, batch_size=1) # pass the ds object to DataLoader
for i, x in enumerate(dl):  # now you can use enumerate
    print(i, x)
#output
tensor([-0.2351, 1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])
More details can be found in this official PyTorch tutorial.
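As a rough check of why the DataLoader loop stops after five items, the sketch below reuses the NewDS class from the question. It assumes the DataLoader defaults (no shuffling, no custom sampler), in which case indices are drawn from range(len(ds)), so __len__ is respected. Note that each batch carries a leading batch dimension.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # 10 items in total

    def __len__(self):
        return len(self.data) - 5  # report only the first 5 items

    def __getitem__(self, index):
        return self.data[index]


ds = NewDS()
# Default settings: sequential sampling over indices 0 .. len(ds) - 1
dl = DataLoader(ds, batch_size=1)

batches = list(dl)
print(len(dl))           # number of batches the loader will yield
print(batches[0].shape)  # each batch has shape (batch_size, 2)
```

With batch_size=1 every batch wraps a single item, so x[0] recovers the original row.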
You can use torch.utils.data.Subset to get a subset of the data:
top_five = torch.utils.data.Subset(ds, indices=range(5)) # Get first five items
for i, x in enumerate(top_five):
    print(i)
0
1
2
3
4
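Since the question mentions a train/validation split, here is a minimal sketch of how Subset could be used for that. The TensorDataset of ten integers is only a stand-in for the real data file, and the 8/2 split sizes are arbitrary assumptions. torch.utils.data.random_split is shown as an alternative when a random rather than contiguous split is wanted.

```python
import torch
from torch.utils.data import TensorDataset, Subset, random_split

# Stand-in dataset with 10 items (assumption: replaces the real data file)
full = TensorDataset(torch.arange(10))

# Contiguous split: first 8 items for training, last 2 for validation
train = Subset(full, range(0, 8))
valid = Subset(full, range(8, 10))
print(len(train), len(valid))

# Each item of a TensorDataset is a tuple of tensors, hence item[0]
valid_values = [int(item[0]) for item in valid]
print(valid_values)

# Random split with a fixed seed so that the split is reproducible
generator = torch.Generator().manual_seed(0)
rand_train, rand_valid = random_split(full, [8, 2], generator=generator)
print(len(rand_train), len(rand_valid))
```

Both approaches leave the underlying dataset untouched and only restrict which indices are visible.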
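enumerate in a loop keeps returning items until it gets a StopIteration exception. Because Dataset does not define __iter__, the iterator is built from __getitem__, which is called with 0, 1, 2, ... until it raises IndexError; __len__ is never consulted, which is why the loop yields all 10 items.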
len(ds)  # returns the modified length
5
# `enumerate` calls `next` on the underlying iterator each time through the loop.
# When no more data is available, a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds)) #11th time StopIteration exception raised as no item left to iterate in iterable
Output:
tensor([-1.5952, -0.0826])
tensor([-2.2254, 0.2461])
tensor([-0.8268, 0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
10 print(next(iter_ds))
11
---> 12 print(next(iter_ds)) #11th time StopIteration exception raised as no item left to iterate
StopIteration:
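The behavior above can be reproduced without torch. When a class defines __getitem__ but not __iter__, Python falls back to the sequence protocol: it calls __getitem__ with 0, 1, 2, ... and stops at the first IndexError, ignoring __len__. The two classes below are hypothetical stand-ins for NewDS, not part of PyTorch. The second one raises IndexError itself once the index reaches the reported length, which is one possible way to make a plain for loop stop at 5.

```python
class Unbounded:
    """Hypothetical stand-in for NewDS: __len__ says 5, data holds 10 items."""

    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data) - 5

    def __getitem__(self, index):
        # IndexError is only raised by the list itself, at index 10
        return self.data[index]


class Bounded(Unbounded):
    """Same data, but __getitem__ enforces the reported length."""

    def __getitem__(self, index):
        # Only non-negative indices are checked in this sketch
        if index >= len(self):
            raise IndexError(index)
        return self.data[index]


seen_unbounded = [i for i, _ in enumerate(Unbounded())]
seen_bounded = [i for i, _ in enumerate(Bounded())]
print(seen_unbounded)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(seen_bounded)    # [0, 1, 2, 3, 4]
```

The same bounds check could be added to __getitem__ in a Dataset subclass, although wrapping the dataset in Subset or a DataLoader avoids changing the class at all.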