簡體   English   中英

如何從 MNIST 數據集中選擇每個類的特定數量

[英]How to select a specific number of each class from the MNIST dataset

我正在使用 tensorflow 來處理 Mnist。 我需要用每個類的特定數量的數據(例如每個數字的 500 個樣本)來訓練我的網絡。 我找到了如何使用類標簽對數據庫進行排序

idx = np.argsort(y_train)
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]

但是我怎樣才能選擇每一個的 500 個號碼,然后將它們和洗牌結合起來?

如果你把所有在一個DataFrame ,那么你可以groupby標簽,然后讓headtail

import pandas as pd

df = pd.DataFrame({
    'X1': [1,2,3,4,5,6,7,8,9,10,11,12],
    'X2': [21,22,23,24,25,26,27,28,29,30,31,32],
    'label': ['a','a','a','a','b','b','b','b','c','c','c','c']
})

groups = df.groupby('label')

df2 = groups.head(2)    
#df2 = groups.apply(lambda x:x[:2]) # the same as head(2)
#df2 = groups.apply(lambda x:x.sample(frac=1)[:2]) # shuffled before get values

print(df2)

結果

   X1  X2 label
0   1  21     a
1   2  22     a
4   5  25     b
5   6  26     b
8   9  29     c
9  10  30     c

之后,您可以將其洗牌並將其拆分為X_trainy_train

df2 = df2.sample(frac=1).reset_index(drop=True)

X_train = df2[['X1','X2']]
y_train = df2['label']

print(X_train)
print(y_train)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM