How to sort a dataset in PyTorch

I would like to sort my dataset by the numerical values in the labels.

Is there a PyTorch function that handles this efficiently?

My dataset's type() is:

 <class 'torchvision.datasets.mnist.MNIST'>

There is no generic way to do this efficiently, because a dataset class only has to implement the __getitem__ and __len__ methods and does not necessarily store any information about the labels.
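For a dataset that keeps no label list, the only fully general fallback is to call __getitem__ once per item, collect the labels, and then index the dataset in sorted order. The sketch below assumes every item is an (input, label) pair with a scalar numeric label; `sort_by_label` is a hypothetical helper name. It uses `torch.utils.data.Subset` to expose the reordered view without copying samples. Note that it loads (and transforms) every sample once, which is exactly the inefficiency described above.

```python
import torch
from torch.utils.data import Dataset, Subset, TensorDataset


def sort_by_label(dataset: Dataset) -> Subset:
    """Return a view of `dataset` ordered by ascending label.

    Assumption: each item is an (input, label) pair with a scalar numeric label.
    This calls __getitem__ on every item, so it costs one full pass over the data.
    """
    labels = torch.tensor([int(dataset[i][1]) for i in range(len(dataset))])
    # stable=True keeps the original relative order of samples with equal labels
    order = torch.sort(labels, stable=True).indices
    return Subset(dataset, order.tolist())


# Tiny illustrative dataset (made-up values) instead of a real one.
x = torch.arange(6, dtype=torch.float32).unsqueeze(1)
y = torch.tensor([3, 1, 2, 1, 0, 2])
sorted_ds = sort_by_label(TensorDataset(x, y))
print([int(sorted_ds[i][1]) for i in range(len(sorted_ds))])  # [0, 1, 1, 2, 2, 3]
```

A `Subset` is used rather than building a new tensor so that this works for any map-style dataset, including ones whose samples are loaded lazily from disk.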

In the case of the MNIST dataset class, however, the labels are held in memory, so you can sort the dataset using that label list.

For example, to list the indices of the samples that have the label 5:

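import torchvision

mnist = torchvision.datasets.MNIST(root="data", train=True, download=True)
labels = mnist.targets  # `train_labels` is a deprecated alias of `targets`
fives = (labels == 5).nonzero(as_tuple=True)[0]  # 1-D tensor of indices with label 5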
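To actually sort (not just filter) the MNIST dataset, one option is to sort the in-memory `targets` tensor and wrap the dataset in a `Subset` with the resulting order, so no image is loaded while sorting. The sketch below assumes the dataset exposes a `targets` attribute, as torchvision's MNIST does; `sort_by_targets`, `indices_per_label` and the `ToyLabelled` stand-in are hypothetical names, and the stand-in only exists to avoid downloading MNIST. With the real dataset you would pass `torchvision.datasets.MNIST(root="data", download=True)` instead.

```python
import torch
from torch.utils.data import Dataset, Subset


def sort_by_targets(dataset: Dataset, descending: bool = False) -> Subset:
    """Order a dataset by its in-memory `targets` (e.g. torchvision MNIST)."""
    targets = torch.as_tensor(dataset.targets)  # also accepts a plain list of labels
    # Sorting only the label tensor; stable=True preserves order among equal labels.
    order = torch.sort(targets, descending=descending, stable=True).indices
    return Subset(dataset, order.tolist())


def indices_per_label(targets: torch.Tensor) -> dict:
    """Map each label value to a 1-D tensor of the indices that carry it."""
    return {int(c): (targets == c).nonzero(as_tuple=True)[0] for c in targets.unique()}


class ToyLabelled(Dataset):
    """Hypothetical stand-in exposing `targets` like MNIST, to avoid a download."""

    def __init__(self, targets):
        self.targets = torch.tensor(targets)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        # Dummy 1x28x28 image paired with the stored label
        return torch.zeros(1, 28, 28), int(self.targets[i])


toy = ToyLabelled([5, 0, 5, 3, 0])
sorted_toy = sort_by_targets(toy)
print([sorted_toy[i][1] for i in range(len(sorted_toy))])  # [0, 0, 3, 5, 5]
print(indices_per_label(toy.targets)[5].tolist())  # [0, 2]
```

Because only the label tensor is sorted, the cost is a single sort over 60,000 integers for the MNIST training set, regardless of image size.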

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM