繁体   English   中英

如何从 Keras 提供的 MNIST 数据集中仅选择特定数字?

[英]How do I select only a specific digit from the MNIST dataset provided by Keras?

我目前正在使用 Keras 在 MNIST 数据集上训练前馈神经网络。 我正在使用格式加载数据集

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

但后来我只想使用数字 0 和 4 来训练我的模型,而不是全部。 如何只选择2位数? 我对 python 相当陌生,可以弄清楚如何过滤 mnist 数据集......

Y_trainY_test为您提供图像标签,您可以将它们与numpy.where一起使用以过滤掉带有 0 和 4 的标签子集。 你所有的变量都是 numpy 数组,所以你可以简单地做;

import numpy as np

train_filter = np.where((Y_train == 0 ) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

您可以使用这些过滤器按索引获取数组的子集。

X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]

如果您对 2 个以上的标签感兴趣,那么 where 和 or 的语法可能会很麻烦。 所以你也可以使用numpy.isin来创建掩码。

train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

您可以像以前一样使用这些掩码进行布尔索引。

你有标签文件以及训练和测试:

train_images = mnist.train_images()
train_labels = mnist.train_labels()

test_images = mnist.test_images()
test_labels = mnist.test_labels()

您可以将它们与简单的列表理解一起使用来过滤数据集

zero_four_test = [test_images[key] for (key, label) in enumerate(test_labels) if int(label) == 0 or int(label) == 4]

当数字不连续且从 0 开始时,使用Y_train = Y_train[train_mask]会引发InvalidArgumentError (keras 期望从 0 开始的连续标签范围)

解决方案(两位数)是:

train_mask = np.isin(Y_train, [2,8])
test_mask = np.isin(Y_test, [2,8])

X_train, Y_train = X_train[train_mask], np.array(Y_train[train_mask] == 8)
X_test, Y_test = X_test[test_mask], np.array(Y_test[test_mask] == 8)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM