
In Python 3: strange behaviour of list(iterables)

I have a specific question about the behaviour of iterables in Python. My iterable is a custom-built Dataset class in PyTorch:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

The weird behaviour appears when I initialize an instance of this datasetTest class. Depending on which data structure I pass as the argument X, calling list(datasetTestInstance) behaves differently. In particular, passing a torch.tensor causes no problem, but passing a dict raises a KeyError. The reason is that list(iterable) does not just request the indices i=0, ..., len(iterable)-1; it also requests i=len(iterable). That is, it iterates up to and including the index equal to the length of the iterable. Obviously, this index is not defined in any Python data structure, since the last element always has index len(datastructure)-1, not len(datastructure). If X is a torch.tensor or a list, no error is raised, even though I think there should be one. __getitem__ is still called for the (non-existent) element with index len(datasetTestInstance), but it does not seem to compute y = self.X[len(datasetTestInstance)]. Does anyone know whether PyTorch somehow handles this gracefully internally?

When passing a dict as the data, it throws an error in the last iteration, when x = len(datasetTestInstance). That is actually the behaviour I would expect. But why does this only happen for a dict, and not for a list or a torch.tensor?

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

You can run the snippet above if you want to reproduce what I have observed.

My questions are:

1.) Why does list(iterable) iterate up to and including index len(iterable)? A for loop doesn't do that.

2.) When a torch.tensor or a list is passed as the data X: why is no error thrown, even though __getitem__ is called with the index len(datasetTestInstance), which should be out of range since it is not a valid index into the tensor/list? Or, in other words: once the index len(datasetTestInstance) is reached and the __getitem__ method is entered, what exactly happens? It apparently no longer executes y = self.X[x] (otherwise there would be an IndexError), but it DOES enter __getitem__, which I can see because the method prints the index x. So what happens inside that method? And why does the behaviour differ depending on whether X is a torch.tensor/list or a dict?

This isn't really a PyTorch-specific issue; it's a general Python question.

You're constructing a list using list(iterable), where your iterable is a class that implements sequence semantics.

Take a look at the expected behaviour of __getitem__ for sequence types (the most relevant parts are in bold in the original docs):

object.__getitem__(self, key)

Called to implement evaluation of self[key]. For sequence types, the accepted keys should be integers and slice objects. Note that the special interpretation of negative indexes (if the class wishes to emulate a sequence type) is up to the __getitem__() method. If key is of an inappropriate type, TypeError may be raised; if of a value outside the set of indexes for the sequence (after any special interpretation of negative values), IndexError should be raised. For mapping types, if key is missing (not in the container), KeyError should be raised.

Note: for loops expect that an IndexError will be raised for illegal indexes to allow proper detection of the end of the sequence.

The problem here is that for sequence types, Python expects an IndexError when __getitem__ is invoked with an invalid index, and the list constructor relies on this behaviour to detect the end of the sequence. In your example, when X is a dict, accessing an invalid key makes __getitem__ raise KeyError instead. That exception isn't expected, so it isn't caught, and the construction of the list fails.
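A minimal sketch of the two behaviours described above, in plain Python with no PyTorch (the class names are mine, purely illustrative). When a class defines __getitem__ but not __iter__, list() falls back to the old sequence protocol: it calls __getitem__(0), __getitem__(1), ... until an IndexError signals the end. Any other exception, such as KeyError, simply propagates:

```python
class StopsWithIndexError:
    """Sequence-style: IndexError is treated as end-of-sequence."""
    def __getitem__(self, i):
        if i >= 3:
            raise IndexError  # list() catches this and stops cleanly
        return i * 10

class StopsWithKeyError:
    """Mapping-style: KeyError is NOT treated as end-of-sequence."""
    def __getitem__(self, i):
        if i >= 3:
            raise KeyError(i)  # list() does not catch this
        return i * 10

print(list(StopsWithIndexError()))  # [0, 10, 20]

try:
    list(StopsWithKeyError())
except KeyError as e:
    print("KeyError propagated:", e)
```

This is exactly why __getitem__ is entered with index len(...) in the question: list() has no other way to find the end of a sequence-protocol iterable than to probe the next index and wait for IndexError.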


Based on this, you could do something like the following:

class datasetTest:
    def __init__(self):
        self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

I can't recommend doing this in practice, since it relies on your dictionary X containing exactly the integer keys 0, 1, ..., len(X)-1, which means it ends up behaving just like a list in most cases; you're probably better off just using a list.
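If the dict's values are what you actually want, a simpler alternative sketch (hypothetical class name) is to define __iter__ so that list() uses the iterator protocol and never touches the sequence protocol at all; then the keys can be anything:

```python
class DictValuesDataset:
    """Iterates over the dict's values; keys are irrelevant."""
    def __init__(self, data):
        self.X = data

    def __len__(self):
        return len(self.X)

    def __iter__(self):
        # list() prefers __iter__ over __getitem__, so no
        # IndexError bookkeeping is needed here.
        return iter(self.X.values())

d = DictValuesDataset({0: 12, 1: 35, 2: 99, 3: 27, 4: 33})
print(list(d))  # [12, 35, 99, 27, 33]
```

Because __iter__ is defined, list() obtains a values iterator and stops on the natural StopIteration, regardless of what the dictionary's keys look like.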

A bunch of useful links:

  1. [Python 3.Docs]: Data model - Emulating container types
  2. [Python 3.Docs]: Built-in Types - Iterator Types
  3. [Python 3.Docs]: Built-in Functions - iter(object[, sentinel])
  4. [SO]: Why does list ask about __len__? (all answers)

The key point is that the list constructor uses the (iterable) argument's __len__ (if provided) as a hint to size the new container, but then fills it by iterating over the argument (via the iterator protocol).
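A small sketch of that claim, mirroring the SimpleDataset below (the object is its own iterator, so the size hint taken from the iterator lands in __len__; the class name is mine):

```python
class Probe:
    """Logs which special methods list() actually calls."""
    def __init__(self):
        self._it = None

    def __len__(self):
        print("__len__ called (size hint)")
        return 2

    def __iter__(self):
        print("__iter__ called")
        self._it = iter(["a", "b"])
        return self  # the object is its own iterator

    def __next__(self):
        print("__next__ called")
        return next(self._it)  # raises StopIteration at the end

result = list(Probe())
print(result)  # ['a', 'b']
```

Running this prints __iter__, then __len__, then three __next__ calls (the third one raising StopIteration), the same call pattern visible in the output trace further down; the elements come from __next__, not from __len__.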

Your example worked the way it did (iterating over all the keys and failing only on the index equal to the dictionary length) because of a terrible coincidence (remember that dict supports the iterator protocol, and that iteration runs over its keys, which form a sequence):

  • Your dictionary only has int keys (and, more than that:)
  • Each key is equal to its own index (position) in the key sequence

Changing either of the conditions in the two bullets above would make the actual error far more eloquent.
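For instance (a hypothetical variation of the question's class): give the dict string keys, and the very first sequence-protocol lookup already fails with an unambiguous KeyError, instead of appearing to work until the last index:

```python
class StrKeyDataset:
    """Like the question's datasetTest, but with non-int dict keys."""
    def __init__(self):
        self.X = {"a": 12, "b": 35}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        return self.X[x]

try:
    list(StrKeyDataset())  # sequence protocol starts probing at index 0
except KeyError as e:
    print("fails immediately:", e)  # KeyError: 0
```

Here the mismatch between the sequence protocol (integer indexes) and the mapping (string keys) surfaces on the first call, making the root cause obvious.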

Both objects (the dict and the list (of tensors)) support the iterator protocol. To make things work, you should wrap it in your Dataset class, tweaking the mapping-type variant a bit (to work with values instead of keys).
The code (the key_func-related parts) is a bit complex, but only in order to be easily configurable (in case you want to change something, for demo purposes).

code00.py:

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    {0:}".format(val))
            return val
        except:
            print("    exc")
            raise #IndexError

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: {0:}".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: {0:d}".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  {0:}".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: {0:}\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        ({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

Output:

[cfati@CFATI-5510-0:e:\Work\Dev\StackOverflow\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32

Inner object: tensor([[ 0.6626,  0.1107],
        [-0.1118,  0.6177]])
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(int))
    tensor([0.6626, 0.1107])
  0: tensor([0.6626, 0.1107])
    __getitem__(1(int))
    tensor([-0.1118, 0.6177])
  1: tensor([-0.1118, 0.6177])

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  tensor([0.6626, 0.1107])
    __next__()
  tensor([-0.1118, 0.6177])
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [tensor([0.6626, 0.1107]), tensor([-0.1118, 0.6177])]

Inner object: {'1': 86, '0': 25}
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(str))
    25
  0: 25
    __getitem__(1(str))
    86
  1: 86

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  86
    __next__()
  25
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [86, 25]

Done.

You might also want to check [PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET (IterableDataset).
