How to select a specific number of samples of each class from the MNIST dataset

I'm using TensorFlow to work on MNIST. I need to train my network with a specific number of samples from each class (for example, 500 samples of each digit). I've found how to sort the dataset by class label:

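import numpy as np

# Indices that order the training set by label (all 0s, then all 1s, ..., then all 9s)
idx = np.argsort(y_train)
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]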

But how can I select 500 samples of each class, combine them, and then shuffle the result?

If you have everything in one DataFrame, you can group by label and then take head(n) or tail(n) of each group:

import pandas as pd

df = pd.DataFrame({
    'X1': [1,2,3,4,5,6,7,8,9,10,11,12],
    'X2': [21,22,23,24,25,26,27,28,29,30,31,32],
    'label': ['a','a','a','a','b','b','b','b','c','c','c','c']
})

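groups = df.groupby('label')

# First 2 rows of each class (use groups.tail(2) for the last 2)
df2 = groups.head(2)
#df2 = groups.apply(lambda x: x[:2])  # same rows as head(2), but the group label is added as an extra index level
#df2 = groups.apply(lambda x: x.sample(frac=1)[:2])  # shuffle each group first, so 2 random rows per class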

print(df2)

Result

   X1  X2 label
0   1  21     a
1   2  22     a
4   5  25     b
5   6  26     b
8   9  29     c
9  10  30     c
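For MNIST itself, the 28x28 images do not fit naturally into DataFrame columns, but the same group-then-select idea still works if only the labels go into pandas and the resulting index is used to pick rows from the NumPy arrays. The following is a sketch, assuming pandas >= 1.1 (where GroupBy.sample was added); sample_n_per_class is a hypothetical helper name, and the random arrays stand in for the real x_train/y_train. Unlike head(n), it draws n random rows from each class.

```python
import numpy as np
import pandas as pd


def sample_n_per_class(x, y, n, seed=None):
    """Hypothetical helper: draw n random samples of every class from
    NumPy arrays x (images) and y (labels), shuffled together."""
    # Only the labels go into pandas; the default RangeIndex records
    # each sample's row position in x and y.
    labels = pd.Series(y)
    # n random rows per class; raises ValueError if a class has fewer than n
    picked = labels.groupby(labels).sample(n=n, random_state=seed)
    # Shuffle the combined selection so the classes are interleaved
    idx = picked.sample(frac=1, random_state=seed).index.to_numpy()
    return x[idx], y[idx]


# Synthetic stand-in for MNIST: 1000 "images", exactly 100 per digit
x_train = np.random.default_rng(0).random((1000, 28, 28))
y_train = np.tile(np.arange(10), 100)

x_small, y_small = sample_n_per_class(x_train, y_train, n=50, seed=0)
print(x_small.shape)         # (500, 28, 28)
print(np.bincount(y_small))  # [50 50 50 50 50 50 50 50 50 50]
```

For the question's case, calling it with the real arrays and n=500 would give a 5000-sample balanced subset.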

After that, you can shuffle the rows and split them into X_train and y_train:

# Shuffle all rows; reset_index(drop=True) renumbers them 0..n-1
df2 = df2.sample(frac=1).reset_index(drop=True)

X_train = df2[['X1','X2']]
y_train = df2['label']

print(X_train)
print(y_train)
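The sorted arrays from the question can also be used directly, without pandas: after argsort each class occupies one contiguous block, so it is enough to find where each block starts, take n indices from it, and shuffle images and labels with one shared permutation. This is a minimal NumPy-only sketch, assuming integer labels as in MNIST; take_n_per_class is a hypothetical helper name, and the synthetic arrays stand in for the real data.

```python
import numpy as np


def take_n_per_class(x_sorted, y_sorted, n, seed=None):
    """Hypothetical helper: take the first n samples of every class from
    label-sorted arrays, then shuffle images and labels together."""
    # After argsort each class is one contiguous block; get each block's
    # start offset and its length.
    _, starts, counts = np.unique(y_sorted, return_index=True, return_counts=True)
    if np.any(counts < n):
        raise ValueError(f"every class needs at least {n} samples")
    # First n positions of every block, joined into one index array
    idx = np.concatenate([np.arange(s, s + n) for s in starts])
    # One shared permutation keeps each image paired with its label
    idx = np.random.default_rng(seed).permutation(idx)
    return x_sorted[idx], y_sorted[idx]


# Synthetic stand-in for MNIST: 1000 "images", exactly 100 per digit.
# With real data: (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = np.random.default_rng(0).random((1000, 28, 28))
y_train = np.tile(np.arange(10), 100)

idx = np.argsort(y_train, kind="stable")
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]

x_small, y_small = take_n_per_class(x_train_sorted, y_train_sorted, n=50, seed=1)
print(x_small.shape)         # (500, 28, 28)
print(np.bincount(y_small))  # [50 50 50 50 50 50 50 50 50 50]
```

Taking the first n of each block is deterministic. To draw n random samples per class instead, one option is to pick from each block with np.random.default_rng().choice(..., replace=False) rather than np.arange.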
