
Pandas sample from df keeping balance of groups

Let's generate a sample dataframe:

import pandas as pd
categs = ['cat1'] * 600 + ['cat2'] * 300 + ['cat3'] * 100
subcats = ['sub1', 'sub2', 'sub2', 'sub3', 'sub3', 'sub4', 'sub4', 'sub4', 'sub4', 'sub4'] * 100
subcats[0] = 'subX'
vals = range(1000)
df = pd.DataFrame({
    'category': categs,
    'subcategory': subcats,
    'values': vals
})

Now let's look at the number of rows per category and subcategory:

print(df.groupby(['category', 'subcategory']).size())

which gives:

category  subcategory
cat1      sub1            59
          sub2           120
          sub3           120
          sub4           300
          subX             1
cat2      sub1            30
          sub2            60
          sub3            60
          sub4           150
cat3      sub1            10
          sub2            20
          sub3            20
          sub4            50
dtype: int64

This is a dataframe of 1000 rows: 600 of cat1, 300 of cat2 and 100 of cat3. What I want is to reduce it from 1000 to, say, 60 rows so that:
1) each category has the same number of rows (20 in our case, i.e. 60 / number of categories);
2) the proportion of each subcategory within a category is kept;
3) a subcategory with very few items still stays in its category (there is only one 'subX' in cat1; we need to keep it even though its proportion within cat1 is only 1/600).
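The target counts implied by these three rules can be worked out directly; a minimal sketch of the arithmetic, using the cat1 counts from the groupby above:

```python
import math

# subcategory counts within cat1 (taken from the groupby output above)
cat1 = {'sub1': 59, 'sub2': 120, 'sub3': 120, 'sub4': 300, 'subX': 1}
per_cat = 60 / 3  # 20 target rows per category

# round up, so rare subcategories like 'subX' keep at least one row
targets = {k: math.ceil(v / sum(cat1.values()) * per_cat)
           for k, v in cat1.items()}
print(targets)  # {'sub1': 2, 'sub2': 4, 'sub3': 4, 'sub4': 10, 'subX': 1}
```

Rounding up is what makes the per-category total land slightly above 20 (here, 21 for cat1).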

So the new df should look something like this:

print(newdf.groupby(['category', 'subcategory']).size())


category  subcategory
cat1      sub1            2
          sub2            4
          sub3            4
          sub4           10
          subX            1
cat2      sub1            2
          sub2            4
          sub3            4
          sub4           10
cat3      sub1            2
          sub2            4
          sub3            4
          sub4           10
dtype: int64

In this case there are 21 elements for cat1, but that is not a big deal; the main idea is that the subcategory proportions are preserved and the number of rows per category is around the target of 20.

You can compute the number of rows to keep per subcategory, then keep only the rows whose cumcount is below that number:

import numpy as np

# total (approximate) number of rows to keep
n = 60

# number of rows to keep per category
n_per_cat = n / df['category'].nunique()

# number of rows to keep per subcategory, rounded up so that
# rare subcategories (like 'subX') keep at least one row
g_subcat = df.groupby(['category', 'subcategory'])
z = g_subcat.size()
n_per_subcat = np.ceil(z / z.groupby(level=0).transform('sum') * n_per_cat)

# keep the first n rows of each subcategory
df_out = (df
          .assign(i=g_subcat.cumcount())
          .merge(n_per_subcat.rename('n').reset_index())
          .query('i < n')
          .drop(columns=['i', 'n']))

# test
df_out.groupby(['category', 'subcategory']).size()

Output:

category  subcategory
cat1      sub1            2
          sub2            4
          sub3            4
          sub4           10
          subX            1
cat2      sub1            2
          sub2            4
          sub3            4
          sub4           10
cat3      sub1            2
          sub2            4
          sub3            4
          sub4           10

P.S. To make the selection random, you can, of course, shuffle the dataframe beforehand:

df = df.sample(frac=1)
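Alternatively, the same quota arithmetic can be combined with `DataFrame.sample` so each subcategory's rows are drawn at random in one pass. A sketch (the example dataframe is rebuilt here so the snippet runs standalone, and `random_state=42` is an arbitrary seed):

```python
import numpy as np
import pandas as pd

# rebuild the example dataframe from the question
df = pd.DataFrame({
    'category': ['cat1'] * 600 + ['cat2'] * 300 + ['cat3'] * 100,
    'subcategory': (['sub1'] + ['sub2'] * 2 + ['sub3'] * 2 + ['sub4'] * 5) * 100,
    'values': range(1000),
})
df.loc[0, 'subcategory'] = 'subX'

n = 60
n_per_cat = n / df['category'].nunique()

# quota per (category, subcategory), rounded up so singletons survive
sizes = df.groupby(['category', 'subcategory']).size()
quota = np.ceil(sizes / sizes.groupby(level=0).transform('sum')
                * n_per_cat).astype(int)

# draw each subcategory's quota at random rather than taking the first rows
parts = [grp.sample(n=quota[key], random_state=42)
         for key, grp in df.groupby(['category', 'subcategory'])]
newdf = pd.concat(parts)

print(newdf.groupby(['category', 'subcategory']).size())
```

This yields the same per-subcategory counts as above (61 rows in total), but the kept rows are sampled uniformly within each subcategory instead of being the first ones.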
