
Should a PyTorch custom dataset class inherit from torch.utils.data.Dataset, and why do both approaches work?

I was studying PyTorch's Dataset class. My understanding was that every time we create our own CustomDataset class, we need to inherit from torch.utils.data.Dataset and then override the __len__ and __getitem__ methods as needed.

However, I learned that inheriting isn't strictly necessary: we can write a CustomDataset class with __len__ and __getitem__ methods without inheriting from torch.utils.data.Dataset, and an instance of it still behaves pretty much the same (I tested this myself).

That is to say, len(cust_data) returns the length of the dataset we passed in when creating the cust_data instance, and we can index it like cust_data[0] to get whatever the __getitem__ method of our CustomDataset class returns.
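
A minimal sketch of the behaviour described above, using a hypothetical CustomDataset that wraps a plain Python list and never imports PyTorch at all:

```python
class CustomDataset:
    """Hypothetical map-style dataset that does NOT inherit from torch.utils.data.Dataset."""

    def __init__(self, samples):
        # Keep a reference to the underlying data (any sequence works here).
        self.samples = samples

    def __len__(self):
        # Called by the built-in len(cust_data)
        return len(self.samples)

    def __getitem__(self, idx):
        # Called by the indexing syntax cust_data[idx]
        return self.samples[idx]


cust_data = CustomDataset(["a", "b", "c"])
print(len(cust_data))  # 3
print(cust_data[0])    # a
```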

My questions are:

  1. What is the point of inheriting when things work just as well without it? And if they don't, what functionality do we miss out on by not inheriting? When is inheriting recommended, and when is it not? (The official docs always recommend inheriting.)

  2. When we don't inherit, how does the instance know that it needs to call the __getitem__ method when it is indexed?

Any answers appreciated.

  1. Take a look at the source code for torch.utils.data.Dataset. Its docstring describes it as an abstract class, but the "guarantee" is only a runtime one: the base __getitem__ just raises NotImplementedError, so a subclass that forgets to implement __getitem__ fails when it is indexed, not when it is defined. In other words, you don't "need" to inherit from Dataset: as long as __getitem__ (and __len__, which the default samplers rely on) is properly implemented, your dataset class will work fine. Inheriting has become common practice because it signals to third parties (e.g. other code that uses your dataset class, or someone else reading your code) that the class implements __getitem__; it provides a common interface for PyTorch datasets. Subclasses also get a few small extras from the base class, such as dataset_a + dataset_b returning a ConcatDataset.

  2. Evaluating obj[i] on an instance automatically calls type(obj).__getitem__(obj, i), and raises a TypeError ("object is not subscriptable") if the class does not define __getitem__. This is part of Python's data model and has nothing to do with which base class you inherit from. Search for "dunder methods" (Python's special methods, such as __len__ and __getitem__) to learn more about these special behaviors.
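
To make point 1 concrete, here is a sketch assuming a recent PyTorch install; the class names SquaresDataset, PlainSquares and Forgetful are hypothetical. It checks that DataLoader accepts a class that does not inherit from Dataset, that only the subclass supports + (via Dataset.__add__, which returns a ConcatDataset) and passes isinstance(..., Dataset), and that a subclass missing __getitem__ can still be instantiated and only fails once it is indexed.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical subclass of Dataset: item i is the scalar tensor i**2."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx ** 2)


class PlainSquares:
    """Same two methods, but no inheritance (plain duck typing)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx ** 2)


class Forgetful(Dataset):
    """Hypothetical subclass that forgets to override __getitem__."""


# A map-style DataLoader gets samples through len() and indexing,
# so the class that does not inherit works here too.
batches = [batch.tolist() for batch in DataLoader(PlainSquares(4), batch_size=2)]
print(batches)  # [[0, 1], [4, 9]]

# Inheriting provides Dataset.__add__, which wraps both operands in a ConcatDataset.
combined = SquaresDataset(3) + SquaresDataset(2)
print(type(combined).__name__, len(combined))  # ConcatDataset 5

# Code that checks isinstance(..., Dataset) only recognises the subclass.
print(isinstance(SquaresDataset(1), Dataset), isinstance(PlainSquares(1), Dataset))  # True False

# The plain class has no __add__, so "+" raises TypeError.
try:
    PlainSquares(3) + PlainSquares(2)
except TypeError as exc:
    print("TypeError:", exc)

# Dataset is not an abc.ABC: Forgetful() can be created, and the missing
# __getitem__ only shows up (as NotImplementedError) when it is indexed.
try:
    Forgetful()[0]
except NotImplementedError:
    print("NotImplementedError when indexing Forgetful()")
```

In this sketch the duck-typed class is enough for a plain DataLoader loop, while subclassing keeps isinstance checks and helpers such as ConcatDataset-via-+ working.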
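
For point 2, a small pure-Python sketch (hypothetical class names, no PyTorch needed) of how indexing is dispatched. It also shows a detail worth knowing: the lookup goes through the type, so attaching __getitem__ to an individual instance does not make that instance indexable.

```python
class Indexable:
    """Hypothetical class that defines __getitem__ on the class."""

    def __getitem__(self, idx):
        return f"item {idx}"


class NotIndexable:
    """Hypothetical class with no __getitem__ on the class."""


obj = Indexable()
# obj[2] looks up __getitem__ on type(obj) and calls it with (obj, 2)
print(obj[2])                         # item 2
print(type(obj).__getitem__(obj, 2))  # item 2 (the same call, written out)

plain = NotIndexable()
# Special methods are looked up on the type, not in the instance __dict__,
# so attaching __getitem__ to this one instance does not enable plain[0].
plain.__getitem__ = lambda idx: "never reached via []"
try:
    plain[0]
except TypeError as exc:
    print("TypeError:", exc)  # TypeError: 'NotIndexable' object is not subscriptable
```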

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM