[英]How to select the specific class in cifar-10
I would like to know how to select the specific class in cifar-10.我想知道如何 select 在 cifar-10 中具体的 class。 For example, I want 7, "horse" class in cifar-10.
例如,我想要 cifar-10 中的 7,“马”class。 And I wrote the below code.
我写了下面的代码。 But the obtained data is not what I want because it's wrong shape.
但是获得的数据不是我想要的,因为它的形状错误。
Please enlighten me on the specifics.请赐教具体情况。
from keras.datasets import cifar10
(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (50000, 32, 32, 3), Y_train shape: (50000, 1)
The below code is wrong.下面的代码是错误的。
import numpy as np
filter = np.where(Y_train == 7)
X_train = X_train[filter]
Y_train = Y_train[filter]
print('X_train shape: {0}, Y_train shape: {1}'.format(X_train.shape, Y_train.shape))
X_train shape: (5000, 32, 3), Y_train shape: (5000,)
The expected output is below预期的 output 如下
X_train shape: (5000, 32, 32, 3), Y_train shape: (5000,)
For the slicing, do something like:对于切片,请执行以下操作:
X_train = X_train[filter[0], ...]
Y_train = Y_train[filter[0], ...]
And the shapes would be形状将是
X_train shape: (5000, 32, 32, 3), Y_train shape: (5000, 1)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.