
In Python 3: strange behaviour of list(iterable)

I have a specific question regarding the behaviour of iterables in Python. My iterable is a custom-built Dataset class in pytorch:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

The weird behaviour comes about when I initialize a specific instance of that datasetTest class. Depending on what data structure I pass as the argument X, it behaves differently when I call list(datasetTestInstance). In particular, when passing a torch.tensor as the argument there is no problem; however, when passing a dict it throws a KeyError. The reason for this is that list(iterable) does not just call i=0, ..., len(iterable)-1, but i=0, ..., len(iterable). That is, it iterates up to and including the index equal to the length of the iterable. Obviously, this index is not defined in any Python data structure, since the last element always has the index len(datastructure)-1 and not len(datastructure). If X is a torch.tensor or a list, no error is raised, even though I think there should be one. It still calls getitem for the (non-existent) element with index len(datasetTestInstance), but it does not seem to compute y = self.X[len(datasetTestInstance)]. Does anyone know whether pytorch somehow handles this gracefully internally?

When passing a dict as the data, it throws an error in the last iteration, when x = len(datasetTestInstance). This is actually the expected behaviour, I guess. But why does it only happen for a dict and not for a list or torch.tensor?

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

You could try out that snippet of code if you want to understand better what I have observed.

My questions are:

1.) Why does list(iterable) iterate up to and including len(iterable)? A for loop doesn't do that.

2.) In the case of a torch.tensor or a list passed as the data X: why does it not throw an error even when calling the getitem method for the index len(datasetTestInstance), which should actually be out of range since it is not defined as an index in the tensor/list? Or, in other words, once the index len(datasetTestInstance) is reached and the getitem method is entered, what happens exactly? It obviously doesn't make the call 'y = self.X[x]' anymore (otherwise there would be an IndexError), but it DOES enter the getitem method, which I can see because it prints the index x from within the getitem method. So what happens in that method? And why does it behave differently depending on whether I have a torch.tensor/list or a dict?

This isn't really a pytorch specific issue, it's a general python question.

You're constructing a list using list(iterable), where the iterable is a class that implements sequence semantics.

Take a look at the documented behavior of __getitem__ for sequence types (the most relevant parts are the sentences about IndexError and KeyError):

object.__getitem__(self, key)

Called to implement evaluation of self[key] . For sequence types, the accepted keys should be integers and slice objects. Note that the special interpretation of negative indexes (if the class wishes to emulate a sequence type) is up to the __getitem__() method. If key is of an inappropriate type, TypeError may be raised; if of a value outside the set of indexes for the sequence (after any special interpretation of negative values), IndexError should be raised. For mapping types, if key is missing (not in the container), KeyError should be raised.

Note: for loops expect that an IndexError will be raised for illegal indexes to allow proper detection of the end of the sequence.

The problem here is that for sequence types Python expects an IndexError when __getitem__ is invoked with an invalid index, and the list constructor relies on this behavior. In your example, when X is a dict, attempting to access an invalid key causes __getitem__ to raise KeyError instead, which isn't expected, so it isn't caught, and the construction of the list fails.
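A minimal, framework-free sketch of that difference (the class names SeqLike and MapLike are made up for illustration):

```python
class SeqLike:
    """Old-style sequence protocol: raise IndexError past the end."""
    def __init__(self, data):
        self.data = data

    def __getitem__(self, i):
        if i >= len(self.data):
            raise IndexError(i)  # signals "end of sequence" to list()/for
        return self.data[i]


class MapLike:
    """Same shape, but a missing key raises KeyError, like a dict lookup."""
    def __init__(self, data):
        self.data = data

    def __getitem__(self, i):
        return self.data[i]  # dict raises KeyError for a missing key


print(list(SeqLike([10, 20, 30])))  # → [10, 20, 30]; IndexError at i=3 ends iteration
try:
    list(MapLike({0: 10, 1: 20, 2: 30}))
except KeyError as e:
    print("KeyError:", e)  # → KeyError: 3; not caught by list(), so it propagates
```

Neither class defines __iter__, so in both cases list() falls back to calling __getitem__ with 0, 1, 2, ... until an exception occurs; only IndexError (or StopIteration) is treated as the end of the data.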


Based on this information you could do something like the following

class datasetTest:
    def __init__(self):
        self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

I can't recommend doing this in practice, since it relies on your dictionary X containing only the integer keys 0, 1, ..., len(X)-1, which means it ends up behaving just like a list in most cases, so you're probably better off just using a list.
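For example (a sketch, assuming the keys really are the consecutive integers 0 .. len(X)-1), you could normalize the data to a list once at construction time instead of special-casing __getitem__:

```python
class datasetTest:
    def __init__(self, X):
        # Flatten an int-keyed dict into a list once; plain lists already
        # raise IndexError for out-of-range indexes, which list() relies on.
        if isinstance(X, dict):
            X = [X[i] for i in range(len(X))]
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index]


d = datasetTest({0: 12, 1: 35, 2: 99, 3: 27, 4: 33})
print(list(d))  # → [12, 35, 99, 27, 33]
```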

A bunch of useful links:

  1. [Python 3.Docs]: Data model - Emulating container types
  2. [Python 3.Docs]: Built-in Types - Iterator Types
  3. [Python 3.Docs]: Built-in Functions - iter ( object[, sentinel] )
  4. [SO]: Why does list ask about __len__? (all answers)

The key point is that the list constructor uses the (iterable) argument's __len__ (if provided) only as a hint to presize the new container, but then actually iterates over it via the iterator protocol.

Your example appeared to work that way (iterating over all the keys and then failing on the one equal to the dictionary length) because of an unfortunate coincidence (remember that dict supports the iterator protocol, and that iteration happens over its keys, which form a sequence):

  • Your dictionary only had int keys (and nothing else)
  • Those keys coincided with their own indexes in that sequence (0, 1, ..., len(X) - 1)

Changing either of the two conditions above would make the actual error more eloquent.
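For instance (a sketch reusing the question's original datasetTest class, with the dict keys changed to strings), the KeyError then occurs immediately at index 0, making it obvious that the integer indexes were never valid keys:

```python
class datasetTest:
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        return self.X[x]


try:
    # No __iter__, so list() probes __getitem__(0), __getitem__(1), ...
    list(datasetTest({"a": 12, "b": 35, "c": 99}))
except KeyError as e:
    print("KeyError:", e)  # → KeyError: 0 (fails at index 0, not at index len(X))
```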

Both inner objects (dict and list (of tensors)) support the iterator protocol. To make things work, you should wrap it in your Dataset class, with a small tweak for the mapping type (iterate over its values instead of its keys).
The code (the key_func related parts) is a bit more complex than strictly needed, but only to be easily configurable (in case you want to change something - for demo purposes).

code00.py :

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    {0:}".format(val))
            return val
        except Exception:
            print("    exc")
            raise  # re-raise (IndexError, KeyError, ...)

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: {0:}".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: {0:d}".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  {0:}".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: {0:}\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        ({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

Output :

 [cfati@CFATI-5510-0:e:\Work\Dev\StackOverflow\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py
 Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32

 Inner object: tensor([[ 0.6626,  0.1107],
         [-0.1118,  0.6177]])
     __len__()
 Dataset length: 2

 Iterating (old fashion way):
     __len__()
     __getitem__(0(int))
     tensor([0.6626, 0.1107])
   0: tensor([0.6626, 0.1107])
     __getitem__(1(int))
     tensor([-0.1118, 0.6177])
   1: tensor([-0.1118, 0.6177])

 Iterating (Python (iterator protocol) way):
     __iter__()
     __next__()
   tensor([0.6626, 0.1107])
     __next__()
   tensor([-0.1118, 0.6177])
     __next__()

 Try building the list:
     __iter__()
     __len__()
     __next__()
     __next__()
     __next__()
   List: [tensor([0.6626, 0.1107]), tensor([-0.1118, 0.6177])]

 Inner object: {'1': 86, '0': 25}
     __len__()
 Dataset length: 2

 Iterating (old fashion way):
     __len__()
     __getitem__(0(str))
     25
   0: 25
     __getitem__(1(str))
     86
   1: 86

 Iterating (Python (iterator protocol) way):
     __iter__()
     __next__()
   86
     __next__()
   25
     __next__()

 Try building the list:
     __iter__()
     __len__()
     __next__()
     __next__()
     __next__()
   List: [86, 25]

 Done.

You might also want to check [PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASET ( IterableDataset ).
