
How do I select only a specific digit from the MNIST dataset provided by Keras?

I'm currently training a feedforward neural network on the MNIST dataset using Keras. I'm loading the dataset like this:

from keras.datasets import mnist
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

However, I only want to train my model on the digits 0 and 4, not on all ten. How do I select just those two digits? I am fairly new to Python and can't figure out how to filter the MNIST dataset.

Y_train and Y_test hold the labels of the images, so you can use them with numpy.where to select the samples labelled 0 or 4. All your variables are NumPy arrays, so you can simply do:

import numpy as np

train_filter = np.where((Y_train == 0) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

You can then use these filters to index the matching subset of both the image and label arrays:

X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]
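A self-contained sketch of the same idea, using small made-up arrays in place of the real MNIST data (the toy shapes and values are assumptions for illustration only). It shows that np.where returns a tuple of index arrays, which indexes the first axis of images and labels alike:

```python
import numpy as np

# Toy stand-ins for MNIST: 6 "images" of 2x2 pixels and their digit labels
X_train = np.arange(6 * 2 * 2).reshape(6, 2, 2)
Y_train = np.array([0, 3, 4, 1, 4, 0])

# np.where with a single condition returns a 1-tuple of index arrays
train_filter = np.where((Y_train == 0) | (Y_train == 4))
print(train_filter[0])  # [0 2 4 5]

# Indexing with that tuple selects along the first axis of both arrays
X_sub, Y_sub = X_train[train_filter], Y_train[train_filter]
print(Y_sub)            # [0 4 4 0]
print(X_sub.shape)      # (4, 2, 2)
```

Because the same index tuple is applied to both arrays, each kept image stays paired with its label.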

If you are interested in more than two labels, chaining | conditions inside np.where gets unwieldy. You can instead use numpy.isin to create boolean masks:

train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

You can use these masks for boolean indexing in the same way as before, e.g. X_train[train_mask] and Y_train[train_mask].
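A minimal sketch of mask-based filtering with an arbitrary set of digits. The toy label and "image" arrays and the wanted list are hypothetical, chosen only so the output is easy to check by hand:

```python
import numpy as np

# Toy labels and matching 1-D "images" (hypothetical data for illustration)
Y_train = np.array([0, 3, 4, 1, 7, 0, 9])
X_train = np.arange(len(Y_train)) * 10

# Boolean mask: True where the label is any of the wanted digits
wanted = [0, 4, 7]
train_mask = np.isin(Y_train, wanted)
print(train_mask)  # [ True False  True False  True  True False]

# Boolean indexing keeps only the rows where the mask is True
X_sub, Y_sub = X_train[train_mask], Y_train[train_mask]
print(Y_sub)       # [0 4 7 0]
print(X_sub)       # [ 0 20 40 50]
```

Adding another digit only means extending the wanted list; the indexing code stays the same.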

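Alternatively, if your mnist module provides separate image and label functions (as the standalone mnist package does, unlike Keras's mnist.load_data()), you have the labels alongside the train and test images: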

train_images = mnist.train_images()
train_labels = mnist.train_labels()

test_images = mnist.test_images()
test_labels = mnist.test_labels()

You can combine the images and labels in a simple list comprehension to filter the dataset:

zero_four_test = [test_images[key] for (key, label) in enumerate(test_labels) if int(label) == 0 or int(label) == 4]
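The comprehension above returns a plain Python list of images and discards the labels. A sketch that keeps images and labels paired and converts the result back to NumPy arrays (toy arrays stand in for test_images and test_labels; their shapes are made up):

```python
import numpy as np

# Toy stand-ins for test_images / test_labels
test_images = np.arange(5 * 2 * 2).reshape(5, 2, 2)
test_labels = np.array([4, 1, 0, 9, 4])

# Walk images and labels in parallel, keeping only pairs labelled 0 or 4
pairs = [(img, lbl) for img, lbl in zip(test_images, test_labels) if int(lbl) in (0, 4)]

# Stack the kept images back into one array; collect the labels the same way
zero_four_images = np.stack([img for img, _ in pairs])
zero_four_labels = np.array([lbl for _, lbl in pairs])

print(zero_four_images.shape)  # (3, 2, 2)
print(zero_four_labels)        # [4 0 4]
```

Note that np.stack raises an error if no sample matches. For large arrays, the vectorized np.isin mask does the same work without a Python-level loop.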

Note that keeping the original labels with Y_train = Y_train[train_mask] can raise an InvalidArgumentError when the selected digits aren't a consecutive range starting at 0, because Keras expects class labels to run from 0 up to the number of classes minus one.

The solution for two digits is to map the labels to 0 and 1, here with 1 meaning "digit 8":

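import numpy as np
# Keep only the images of 2s and 8s
train_mask = np.isin(Y_train, [2, 8])
test_mask = np.isin(Y_test, [2, 8])

# Remap labels to 0/1 (1 = digit 8) so they form a consecutive range starting at 0
X_train, Y_train = X_train[train_mask], (Y_train[train_mask] == 8).astype(np.uint8)
X_test, Y_test = X_test[test_mask], (Y_test[test_mask] == 8).astype(np.uint8)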
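The == 8 comparison only works for two classes. For any number of kept digits, one option (our own choice, not from the answer above) is np.searchsorted against the sorted list of kept digits, which maps each label to its position 0 .. k-1. The helper name remap_labels is hypothetical:

```python
import numpy as np

def remap_labels(y, digits):
    """Map each label in y to its index in sorted(digits).

    Hypothetical helper: assumes every label in y is one of digits and
    returns consecutive class ids 0 .. len(digits) - 1.
    """
    digits = np.sort(np.asarray(digits))
    if not np.isin(y, digits).all():
        raise ValueError("y contains labels that are not in digits")
    # searchsorted returns the position of each label within the sorted digits
    return np.searchsorted(digits, y)

# Toy labels already filtered down to the digits 2, 5 and 8
y = np.array([8, 2, 5, 8, 2])
print(remap_labels(y, [8, 2, 5]))  # [2 0 1 2 0]
```

Apply it after filtering with np.isin; the remapped labels then fit a model whose output layer has len(digits) units. np.unique(y, return_inverse=True) would give the same mapping when every kept digit actually occurs in y.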

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 