
PyTorch Dataset / DataLoader from random source

I have a source of random (non-deterministic, non-repeatable) data that I'd like to wrap in a Dataset and DataLoader for PyTorch training. How can I do this?

__len__ cannot be defined, as the source is infinite (with possible repetition).
__getitem__ cannot be defined, as the source is non-deterministic.

When defining a custom dataset class, you'd ordinarily subclass torch.utils.data.Dataset and define __len__() and __getitem__().
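For contrast, here is a minimal sketch of such a map-style dataset. The class name SquaresDataset and its contents are made up for illustration. Every item is fully determined by its index, which is exactly what a non-repeatable source cannot offer.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):  # hypothetical name, for illustration only
    """Map-style dataset: item i is fully determined by its index."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # A map-style dataset needs a finite, known length
        return self.n

    def __getitem__(self, idx):
        # The same index always yields the same item
        return torch.tensor([idx, idx * idx], dtype=torch.float32)

loader = DataLoader(SquaresDataset(8), batch_size=4)
first_batch = next(iter(loader))  # shape (4, 2); shuffle is off by default
```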

However, for cases where you want sequential but not random access, you can use an iterable-style dataset. To do this, you instead subclass torch.utils.data.IterableDataset and define __iter__(). Whatever is returned by __iter__() should be a proper iterator; it should maintain state (if necessary) and define __next__() to obtain the next item in the sequence. __next__() should raise StopIteration when there's nothing left to read. In your case with an infinite dataset, it never needs to do this.

Here's an example:

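import torch

class MyInfiniteIterator:
    def __iter__(self):
        # A proper iterator returns itself from __iter__()
        return self

    def __next__(self):
        # Stand-in for the random source: one fresh sample per call.
        # Never raises StopIteration, so the stream is infinite.
        return torch.randn(10)

class MyInfiniteDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        return MyInfiniteIterator()

dataset = MyInfiniteDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

max_steps = 100  # placeholder stopping condition; use your own

for step, batch in enumerate(dataloader):
    # batch has shape (32, 10)
    # ... Do some stuff here ...

    # The loop never ends on its own, so break out explicitly
    if step + 1 >= max_steps:
        break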
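
As a variation on the class-based iterator above, __iter__() can also be written as a generator, since a generator function already returns a proper iterator. The sketch below assumes that torch.randn stands in for the real random source. RandomStreamDataset, sample_dim and max_steps are names made up for this example, and itertools.islice is just one way to cap the number of batches drawn from the endless stream.

```python
import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset

class RandomStreamDataset(IterableDataset):  # hypothetical name
    """Iterable-style dataset backed by an endless random source."""

    def __init__(self, sample_dim=10):
        self.sample_dim = sample_dim

    def __iter__(self):
        # The generator never returns, so the stream never ends
        while True:
            # torch.randn is a stand-in for the real non-repeatable source
            yield torch.randn(self.sample_dim)

loader = DataLoader(RandomStreamDataset(), batch_size=32)

max_steps = 5  # arbitrary budget chosen for this sketch
shapes = []
for batch in itertools.islice(loader, max_steps):
    # ... training step would go here ...
    shapes.append(tuple(batch.shape))
```

With num_workers > 0, each worker process gets its own copy of the dataset object, so a real source would need to be safe to read independently from every worker.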

