簡體   English   中英

在python3中:list(iterables)的奇怪行為

[英]In python3: strange behaviour of list(iterables)

我有一個關於 python 中可迭代對象行為的具體問題。 我的 iterable 是 pytorch 中自定義構建的 Dataset 類:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

當我初始化該 datasetTest 類的特定實例時,現在會出現奇怪的行為。 根據我作為參數 X 傳遞的數據結構,當我調用 list(datasetTestInstance) 時它的行為會有所不同。 特別是,當傳遞一個 torch.tensor 作為參數時沒有問題,但是當傳遞一個 dict 作為參數時它會拋出一個 KeyError。 這樣做的原因是 list(iterable) 不僅調用了 i=0, ..., len(iterable)-1,而且還調用了 i=0, ..., len(iterable)。 也就是說,它將迭代直到(包括)索引等於可迭代的長度。 顯然,這個索引沒有在任何 python 數據結構中定義,因為最后一個元素的索引總是 len(datastructure)-1 而不是 len(datastructure)。 如果 X 是 torch.tensor 或列表,則不會出現錯誤,即使我認為應該是錯誤。 即使對於索引為 len(datasetTestinstance) 的(不存在的)元素,它仍然會調用 getitem 但它不會計算 y=self.X[len(datasetTestInstance]。有誰知道 pytorch 是否在內部以某種方式優雅地處理它?

當將 dict 作為數據傳遞時,它會在最后一次迭代中拋出錯誤,當 x=len(datasetTestInstance) 時。 這實際上是我猜的預期行為。 但是為什么這只發生在 dict 而不是 list 或 torch.tensor?

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

如果您想更好地理解我所觀察到的內容,可以嘗試該代碼片段。

我的問題是:

1.) 為什么 list(iterable) 迭代直到(包括)len(iterable)? for 循環不會那樣做。

2.) 在 torch.tensor 或作為數據 X 傳遞的列表的情況下:為什么即使在為索引 len(datasetTestInstance) 調用 getitem 方法時它也不會拋出錯誤,因為它實際上應該超出范圍,因為它沒有定義作為張量/列表中的索引? 或者,換句話說,當到達索引 len(datasetTestInstance) 然后進入getitem方法時,究竟會發生什么? 它顯然不再調用'y = self.X[x]'(否則會出現IndexError)但它確實進入了我可以看到的getitem方法,因為它從getitem方法中打印索引x。 那么在那個方法中會發生什么呢? 為什么它的行為取決於是否擁有 torch.tensor/list 或 dict?

這實際上並不是一個 pytorch 特定問題,而是一個通用的 Python 問題。

您正在使用list( iterable )構建一個列表,其中一個可迭代類是實現序列語義的類

在此處查看__getitem__對於序列類型的預期行為(最相關的部分以粗體顯示)

object.__getitem__(self, key)

調用以實現self[key]評估。 對於序列類型,接受的鍵應該是整數和切片對象。 請注意,負索引的特殊解釋(如果類希望模擬序列類型)取決於__getitem__()方法。 如果 key 的類型不合適,可能會引發TypeError 如果序列的索引集之外的值(在對負值進行任何特殊解釋之后),則應該引發IndexError 對於映射類型,如果缺少 key(不在容器中),則應引發 KeyError。

注意: for循環期望為非法索引引發IndexError以允許正確檢測序列的結尾。

這里的問題是,對於序列類型,在使用無效索引調用__getitem__的情況下,python 需要一個IndexError 看來list構造函數依賴於這種行為。 在您的示例中,當X是 dict 時,嘗試訪問無效鍵會導致__getitem__引發KeyError而不是預期的,因此不會被捕獲並導致列表的構建失敗。


根據此信息,您可以執行以下操作

class datasetTest:
    def __init__(self):
        self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

我不建議在實踐中這樣做,因為它依賴於你的字典X只包含整數鍵0 , 1 , ..., len(X)-1這意味着它在大多數情況下最終表現得就像一個列表,所以你可能最好只使用一個列表。

一堆有用的鏈接:

  1. [Python 3.Docs]:數據模型 - 模擬容器類型
  2. [Python 3.Docs]:內置類型 - 迭代器類型
  3. [Python 3.Docs]: 內置函數 - iter ( object[, sentinel] )
  4. [SO]:為什么 list 會詢問 __len__? (所有答案)

關鍵是列表構造函數使用(可迭代的)參數的__len__ ((如果提供)來計算新的容器長度),然后對其進行迭代(通過迭代器協議)。

您的示例以這種方式工作(迭代所有鍵並未能達到與字典長度相等的鍵),因為一個可怕的巧合(請記住dict支持迭代器協議,這發生在它的鍵上(這是一個序列)):

  • 你的字典只有int(以及更多)
  • 它們的值與它們的索引相同(按順序)

更改上述 2 個項目符號表示的任何條件,將使實際錯誤更加雄辯。

兩個對象( dictlist張量))都支持迭代器協議。 為了使事情工作,您應該將它包裝在您的Dataset類中,並稍微調整映射類型之一(使用值而不是鍵)。
代碼( key_func相關部分)有點復雜,但只是易於配置(如果您想更改某些內容 - 出於演示目的)。

代碼00.py

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    {0:}".format(val))
            return val
        except:
            print("    exc")
            raise #IndexError

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: {0:}".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: {0:d}".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  {0:}".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: {0:}\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        ({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

輸出

 [cfati@CFATI-5510-0:e:\\Work\\Dev\\StackOverflow\\q059091544]> "e:\\Work\\Dev\\VEnvs\\py_064_03.07.03_test0\\Scripts\\python.exe" code00.py Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32 Inner object: tensor([[ 0.6626, 0.1107], [-0.1118, 0.6177]]) __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(int)) tensor([0.6626, 0.1107]) 0: tensor([0.6626, 0.1107]) __getitem__(1(int)) tensor([-0.1118, 0.6177]) 1: tensor([-0.1118, 0.6177]) Iterating (Python (iterator protocol) way): __iter__() __next__() tensor([0.6626, 0.1107]) __next__() tensor([-0.1118, 0.6177]) __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [tensor([0.6626, 0.1107]), tensor([-0.1118, 0.6177])] Inner object: {'1': 86, '0': 25} __len__() Dataset length: 2 Iterating (old fashion way): __len__() __getitem__(0(str)) 25 0: 25 __getitem__(1(str)) 86 1: 86 Iterating (Python (iterator protocol) way): __iter__() __next__() 86 __next__() 25 __next__() Try building the list: __iter__() __len__() __next__() __next__() __next__() List: [86, 25] Done.

您可能還想檢查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET ( IterableDataset )。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM