简体   繁体   English

如何 select 在 cifar-10 中具体 class

[英]How to select the specific class in cifar-10

I would like to know how to select the specific class in cifar-10.我想知道如何 select 在 cifar-10 中具体的 class。 For example, I want 7, "horse" class in cifar-10.例如,我想要 cifar-10 中的 7,“马”class。 And I wrote the below code.我写了下面的代码。 But the obtained data is not what I want because it's wrong shape.但是获得的数据不是我想要的,因为它的形状错误。

Please enlighten me on the specifics.请赐教具体情况。

from keras.datasets import cifar10

(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()

print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (50000, 32, 32, 3), Y_train shape: (50000, 1)

The below code is wrong.下面的代码是错误的。

import numpy as np

filter = np.where(Y_train == 7)

X_train = X_train[filter]
Y_train = Y_train[filter]

print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (5000, 32, 3), Y_train shape: (5000,)

The expected output is below预期的 output 如下

X_train shape: (5000, 32, 32, 3), Y_train shape: (5000,)

For the slicing, do something like:对于切片,请执行以下操作:

X_train = X_train[filter[0], ...]
Y_train = Y_train[filter[0], ...]

And the shapes would be形状将是

X_train shape: (5000, 32, 32, 3), Y_train shape: (5000, 1)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM