[英]In python3: strange behaviour of list(iterables)
我有一個關於 python 中可迭代對象行為的具體問題。 我的 iterable 是 pytorch 中自定義構建的 Dataset 類:
import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
def __init__(self, X):
self.X = X
def __len__(self):
return len(self.X)
def __getitem__(self, x):
print('***********')
print('getitem x = ', x)
print('###########')
y = self.X[x]
print('getitem y = ', y)
return y
當我初始化該 datasetTest 類的特定實例時,現在會出現奇怪的行為。 根據我作為參數 X 傳遞的數據結構,當我調用 list(datasetTestInstance) 時它的行為會有所不同。 特別是,當傳遞一個 torch.tensor 作為參數時沒有問題,但是當傳遞一個 dict 作為參數時它會拋出一個 KeyError。 這樣做的原因是 list(iterable) 不僅調用了 i=0, ..., len(iterable)-1,而且還調用了 i=0, ..., len(iterable)。 也就是說,它將迭代直到(包括)索引等於可迭代的長度。 顯然,這個索引沒有在任何 python 數據結構中定義,因為最后一個元素的索引總是 len(datastructure)-1 而不是 len(datastructure)。 如果 X 是 torch.tensor 或列表,則不會出現錯誤,即使我認為應該是錯誤。 即使對於索引為 len(datasetTestinstance) 的(不存在的)元素,它仍然會調用 getitem 但它不會計算 y=self.X[len(datasetTestInstance]。有誰知道 pytorch 是否在內部以某種方式優雅地處理它?
當將 dict 作為數據傳遞時,它會在最后一次迭代中拋出錯誤,當 x=len(datasetTestInstance) 時。 這實際上是我猜的預期行為。 但是為什么這只發生在 dict 而不是 list 或 torch.tensor?
if __name__ == "__main__":
a = datasetTest(torch.randn(5,2))
print(len(a))
print('++++++++++++')
for i in range(len(a)):
print(i)
print(a[i])
print('++++++++++++')
print(list(a))
print('++++++++++++')
b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
print(len(b))
print('++++++++++++')
for i in range(len(b)):
print(i)
print(b[i])
print('++++++++++++')
print(list(b))
如果您想更好地理解我所觀察到的內容,可以嘗試該代碼片段。
我的問題是:
1.) 為什么 list(iterable) 迭代直到(包括)len(iterable)? for 循環不會那樣做。
2.) 在 torch.tensor 或作為數據 X 傳遞的列表的情況下:為什么即使在為索引 len(datasetTestInstance) 調用 getitem 方法時它也不會拋出錯誤,因為它實際上應該超出范圍,因為它沒有定義作為張量/列表中的索引? 或者,換句話說,當到達索引 len(datasetTestInstance) 然后進入getitem方法時,究竟會發生什么? 它顯然不再調用'y = self.X[x]'(否則會出現IndexError)但它確實進入了我可以看到的getitem方法,因為它從getitem方法中打印索引x。 那么在那個方法中會發生什么呢? 為什么它的行為取決於是否擁有 torch.tensor/list 或 dict?
這實際上並不是一個 pytorch 特定問題,而是一個通用的 Python 問題。
您正在使用list( iterable )構建一個列表,其中一個可迭代類是實現序列語義的類。
在此處查看__getitem__
對於序列類型的預期行為(最相關的部分以粗體顯示)
object.__getitem__(self, key)
調用以實現
self[key]
評估。 對於序列類型,接受的鍵應該是整數和切片對象。 請注意,負索引的特殊解釋(如果類希望模擬序列類型)取決於__getitem__()
方法。 如果 key 的類型不合適,可能會引發TypeError
; 如果序列的索引集之外的值(在對負值進行任何特殊解釋之后),則應該引發IndexError
。 對於映射類型,如果缺少 key(不在容器中),則應引發 KeyError。注意:
for
循環期望為非法索引引發IndexError
以允許正確檢測序列的結尾。
這里的問題是,對於序列類型,在使用無效索引調用__getitem__
的情況下,python 需要一個IndexError
。 看來list
構造函數依賴於這種行為。 在您的示例中,當X
是 dict 時,嘗試訪問無效鍵會導致__getitem__
引發KeyError
而不是預期的,因此不會被捕獲並導致列表的構建失敗。
根據此信息,您可以執行以下操作
class datasetTest:
def __init__(self):
self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if index < 0 or index >= len(self):
raise IndexError
return self.X[index]
d = datasetTest()
print(list(d))
我不建議在實踐中這樣做,因為它依賴於你的字典X
只包含整數鍵0
, 1
, ..., len(X)-1
這意味着它在大多數情況下最終表現得就像一個列表,所以你可能最好只使用一個列表。
一堆有用的鏈接:
關鍵是列表構造函數使用(可迭代的)參數的__len__ ((如果提供)來計算新的容器長度),然后對其進行迭代(通過迭代器協議)。
您的示例以這種方式工作(迭代所有鍵並未能達到與字典長度相等的鍵),因為一個可怕的巧合(請記住dict支持迭代器協議,這發生在它的鍵上(這是一個序列)):
更改上述 2 個項目符號表示的任何條件,將使實際錯誤更加雄辯。
兩個對象( dict和list (張量))都支持迭代器協議。 為了使事情工作,您應該將它包裝在您的Dataset類中,並稍微調整映射類型之一(使用值而不是鍵)。
代碼( key_func相關部分)有點復雜,但只是易於配置(如果您想更改某些內容 - 出於演示目的)。
代碼00.py :
#!/usr/bin/env python3
import sys
import torch
from torch.utils.data import Dataset
from random import randint
class SimpleDataset(Dataset):
def __init__(self, x):
self.__iter = None
self.x = x
def __len__(self):
print(" __len__()")
return len(self.x)
def __getitem__(self, key):
print(" __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
try:
val = self.x[key]
print(" {0:}".format(val))
return val
except:
print(" exc")
raise #IndexError
def __iter__(self):
print(" __iter__()")
self.__iter = iter(self.x)
return self
def __next__(self):
print(" __next__()")
if self.__iter is None:
raise StopIteration
val = next(self.__iter)
if isinstance(self.x, (dict,)): # Special handling for dictionaries
val = self.x[val]
return val
def key_transformer(int_key):
return str(int_key) # You could `return int_key` to see that it also works on your original example
def dataset_example(inner, key_func=None):
if key_func is None:
key_func = lambda x: x
print("\nInner object: {0:}".format(inner))
sd = SimpleDataset(inner)
print("Dataset length: {0:d}".format(len(sd)))
print("\nIterating (old fashion way):")
for i in range(len(sd)):
print(" {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
print("\nIterating (Python (iterator protocol) way):")
for element in sd:
print(" {0:}".format(element))
print("\nTry building the list:")
l = list(sd)
print(" List: {0:}\n".format(l))
def main():
dict_size = 2
for inner, func in [
(torch.randn(2, 2), None),
({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer), # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
]:
dataset_example(inner, key_func=func)
if __name__ == "__main__":
print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
main()
print("\nDone.")
輸出:
[cfati@CFATI-5510-0:e:\\Work\\Dev\\StackOverflow\\q059091544]> "e:\\Work\\Dev\\VEnvs\\py_064_03.07.03_test0\\Scripts\\python.exe" code00.py Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32 Inner object: tensor([[ 0.6626, 0.1107], [-0.1118, 0.6177]]) __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(int)) tensor([0.6626, 0.1107]) 0: tensor([0.6626, 0.1107]) __getitem__(1(int)) tensor([-0.1118, 0.6177]) 1: tensor([-0.1118, 0.6177]) Iterating (Python (iterator protocol) way): __iter__() __next__() tensor([0.6626, 0.1107]) __next__() tensor([-0.1118, 0.6177]) __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [tensor([0.6626, 0.1107]), tensor([-0.1118, 0.6177])] Inner object: {'1': 86, '0': 25} __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(str)) 25 0: 25 __getitem__(1(str)) 86 1: 86 Iterating (Python (iterator protocol) way): __iter__() __next__() 86 __next__() 25 __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [86, 25] Done.
您可能還想檢查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET ( IterableDataset )。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.