How to convert CIFAR dataset into the same format as MNIST

I'm trying to run a machine learning algorithm on two different datasets. However, the format of the y values differs between the datasets.

from keras.datasets import mnist, cifar10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print([y_train[i] for i in range(10)])
'''
[5, 0, 4, 1, 9, 2, 1, 3, 1, 4]
'''
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print([y_train[i] for i in range(10)])
'''
[array([6], dtype=uint8), array([9], dtype=uint8), array([9], dtype=uint8), array([4], dtype=uint8), array([1], dtype=uint8), array([1], dtype=uint8), array([2], dtype=uint8), array([7], dtype=uint8), array([8], dtype=uint8), array([3], dtype=uint8)]
'''

The Keras documentation says that the format for MNIST is

y_train, y_test: uint8 array of digit labels (integers in range 0-9) with shape (num_samples,)

and that the format for CIFAR is

y_train, y_test: uint8 array of category labels (integers in range 0-9) with shape (num_samples,)

To me these seem like they should be the exact same format. So, I have two questions:

  1. How would I tell, from the documentation, that they actually have different formats? (If this is impossible, then just say so.)

  2. How can I convert the CIFAR dataset to be in the same format as MNIST? (My algorithm currently works on MNIST.)

The issue is that y has a slightly different shape in the two datasets: it is (60000,) for MNIST, but (50000, 1) for CIFAR10. The extra dimension should not cause any problem, but in any case you can get rid of it with:

import numpy as np
y = np.squeeze(y)

Then y's shape will be (50000,).
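As a rough sketch of the same idea, the snippet below compares the shapes directly and flattens the labels. The two small arrays are hypothetical stand-ins that only mimic the shapes reported above (they are not the real datasets), and `flatten_labels` is an illustrative helper name, not part of Keras or NumPy. It uses `reshape(-1)` instead of `np.squeeze`, on the assumption that a 1-D result is wanted even for a single sample.

```python
import numpy as np

# Hypothetical stand-in for CIFAR-10 labels: shape (num_samples, 1), dtype uint8.
cifar_like = np.array([[6], [9], [9], [4], [1]], dtype=np.uint8)
# Hypothetical stand-in for MNIST labels: shape (num_samples,), dtype uint8.
mnist_like = np.array([5, 0, 4, 1, 9], dtype=np.uint8)


def flatten_labels(y):
    """Return labels as a 1-D array of shape (num_samples,)."""
    y = np.asarray(y)
    # reshape(-1) always yields a 1-D array; np.squeeze would turn a
    # single-sample array of shape (1, 1) into a 0-D scalar instead.
    return y.reshape(-1)


flat = flatten_labels(cifar_like)

# Printing .shape is the quickest way to see the difference between the two.
print(cifar_like.shape, "->", flat.shape)
print(mnist_like.shape, "->", flatten_labels(mnist_like).shape)
print([int(v) for v in flat])
```

With the real data, the same helper could be applied to both `y_train` and `y_test` returned by `cifar10.load_data()`; labels that are already 1-D, like the MNIST ones, pass through unchanged.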
