How to get the top n records from each category in a pandas DataFrame?

The following dataframe is sorted in descending order by column 'id':

id   Name     version     copies   price
6    MSFT       10.0        5       100   
6    TSLA       10.0        10      200
6    ORCL       10.0        15      300

5    MSFT       10.0        20      400
5    TSLA       10.0        25      500
5    ORCL       10.0        30      600

4    MSFT       10.0        35      700
4    TSLA       10.0        40      800
4    ORCL       10.0        45      900

3    MSFT       5.0         50      1000 
3    TSLA       5.0         55      1100
3    ORCL       5.0         60      1200

2    MSFT       5.0         65      1300
2    TSLA       5.0         70      1400
2    ORCL       5.0         75      1500

1    MSFT       15.0        80      1600
1    TSLA       15.0        85      1700
1    ORCL       15.0        90      1800
...
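To experiment with the answers below, the rows shown above can be rebuilt as a pandas DataFrame. This is a minimal sketch that assumes pandas is installed and includes only the 18 visible rows; the rows elided by "..." are not reproduced.

```python
import pandas as pd

# Only the 18 rows visible above; the rows elided by "..." are not reproduced.
df = pd.DataFrame({
    'id': [i for i in [6, 5, 4, 3, 2, 1] for _ in range(3)],
    'Name': ['MSFT', 'TSLA', 'ORCL'] * 6,
    'version': [v for v in [10.0, 10.0, 10.0, 5.0, 5.0, 15.0] for _ in range(3)],
    'copies': list(range(5, 95, 5)),
    'price': list(range(100, 1900, 100)),
})

# The question states the frame is sorted by 'id' in descending order.
assert df['id'].is_monotonic_decreasing
# Distinct ids per version, in order of appearance
print(df.groupby('version')['id'].unique())
```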

Based on an input n, I would like to filter the data above so that, for n = 2, the resulting dataframe looks like this:

Name     version     copies   price
MSFT       10.0        5       100   
TSLA       10.0        10      200
ORCL       10.0        15      300

MSFT       10.0        20      400
TSLA       10.0        25      500
ORCL       10.0        30      600

MSFT       5.0         50      1000 
TSLA       5.0         55      1100
ORCL       5.0         60      1200

MSFT       5.0         65      1300
TSLA       5.0         70      1400
ORCL       5.0         75      1500

MSFT       15.0        80      1600
TSLA       15.0        85      1700
ORCL       15.0        90      1800

In other words, within each version only the top n id groups should be kept. If a version has fewer than n distinct ids (e.g. version 15.0 has only one id group, id = 1), all of its id groups should be kept.
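The requirement can also be expressed directly as a per-version dense rank of id: rank the distinct ids of each version from largest to smallest and keep ranks 1 through n. This is a swapped-in technique (groupby(...).rank with method='dense'), not taken from the answers below, and it assumes "top" means the largest ids. Because it ranks by value rather than by position, it does not depend on the frame being sorted.

```python
import pandas as pd

# The 18 sample rows from the question
df = pd.DataFrame({
    'id': [i for i in [6, 5, 4, 3, 2, 1] for _ in range(3)],
    'Name': ['MSFT', 'TSLA', 'ORCL'] * 6,
    'version': [v for v in [10.0, 10.0, 10.0, 5.0, 5.0, 15.0] for _ in range(3)],
    'copies': list(range(5, 95, 5)),
    'price': list(range(100, 1900, 100)),
})

n = 2
# Within each version the largest id gets rank 1 and all rows sharing an id share
# its rank; 'dense' gives the next distinct id the next integer, with no gaps.
rank = df.groupby('version')['id'].rank(method='dense', ascending=False)
result = df[rank <= n]
print(result)
```

Version 15.0 has only one distinct id, so its single id group gets rank 1 and is kept, matching the rule above.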

I tried using groupby and head, but it didn't work for me, and I have no other idea how to get this working.

I really appreciate any help with this, thank you.
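The question doesn't show the exact code that was tried, so as an assumption, suppose it was df.groupby('version').head(n). That call keeps the first n rows of each version rather than the first n distinct ids, which is why it cannot produce the expected output. A sketch on the sample rows:

```python
import pandas as pd

# The 18 sample rows from the question
df = pd.DataFrame({
    'id': [i for i in [6, 5, 4, 3, 2, 1] for _ in range(3)],
    'Name': ['MSFT', 'TSLA', 'ORCL'] * 6,
    'version': [v for v in [10.0, 10.0, 10.0, 5.0, 5.0, 15.0] for _ in range(3)],
    'copies': list(range(5, 95, 5)),
    'price': list(range(100, 1900, 100)),
})

n = 2
# head(n) takes the first n rows of each version group (in the frame's order),
# so every version contributes only 2 rows, here both from its largest id.
rows_only = df.groupby('version').head(n)
print(rows_only[['id', 'Name', 'version']])
```

Applying head to the deduplicated version-id pairs instead of to the rows is what makes the second answer below work.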

You can use groupby.transform on the column version and factorize the column id, which gives each distinct id an incremental value (0, 1, 2, ...) within its version group. Then compare that value to your n and pass the resulting boolean mask to loc to select the wanted rows.

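import pandas as pd
n = 2
# number the distinct ids within each version 0, 1, 2, ... (descending id order) and keep codes < n
print(df.loc[df.groupby('version')['id'].transform(lambda x: pd.factorize(x)[0]) < n])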
    id  Name  version  copies  price
0    6  MSFT     10.0       5    100
1    6  TSLA     10.0      10    200
2    6  ORCL     10.0      15    300
3    5  MSFT     10.0      20    400
4    5  TSLA     10.0      25    500
5    5  ORCL     10.0      30    600
9    3  MSFT      5.0      50   1000
10   3  TSLA      5.0      55   1100
11   3  ORCL      5.0      60   1200
12   2  MSFT      5.0      65   1300
13   2  TSLA      5.0      70   1400
14   2  ORCL      5.0      75   1500
15   1  MSFT     15.0      80   1600
16   1  TSLA     15.0      85   1700
17   1  ORCL     15.0      90   1800
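Here is the same mask approach packaged as a reusable function and run on the sample rows from the question. The function name top_n_id_groups is hypothetical. Because pd.factorize numbers values in order of first appearance, code 0 is the largest id of each version only because the frame is already sorted by id in descending order; on an unsorted frame you would sort first.

```python
import pandas as pd


def top_n_id_groups(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Keep the rows whose id is among the first n distinct ids of their version.

    Hypothetical helper; assumes df is already sorted by 'id' in descending order.
    """
    # pd.factorize labels distinct ids 0, 1, 2, ... in order of first appearance
    codes = df.groupby('version')['id'].transform(lambda x: pd.factorize(x)[0])
    return df.loc[codes < n]


# The 18 sample rows from the question
df = pd.DataFrame({
    'id': [i for i in [6, 5, 4, 3, 2, 1] for _ in range(3)],
    'Name': ['MSFT', 'TSLA', 'ORCL'] * 6,
    'version': [v for v in [10.0, 10.0, 10.0, 5.0, 5.0, 15.0] for _ in range(3)],
    'copies': list(range(5, 95, 5)),
    'price': list(range(100, 1900, 100)),
})

print(top_n_id_groups(df, 2))
```

The mask keeps the original index labels (0-5 and 9-17 above), so the result can still be aligned with df.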

Another option is to use drop_duplicates to keep only the unique version-id pairs, take the first n pairs per version with groupby.head, and then merge these selected version-id pairs back into df.

df.merge(df[['version','id']].drop_duplicates().groupby('version').head(n))
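A runnable version of this merge on the sample rows, split into steps so each intermediate result can be inspected. It relies on merge's defaults (an inner join on all shared columns, here version and id) and, like the first answer, on the descending id order so that head picks the largest ids.

```python
import pandas as pd

# The 18 sample rows from the question
df = pd.DataFrame({
    'id': [i for i in [6, 5, 4, 3, 2, 1] for _ in range(3)],
    'Name': ['MSFT', 'TSLA', 'ORCL'] * 6,
    'version': [v for v in [10.0, 10.0, 10.0, 5.0, 5.0, 15.0] for _ in range(3)],
    'copies': list(range(5, 95, 5)),
    'price': list(range(100, 1900, 100)),
})

n = 2
# One row per distinct (version, id) pair, in the frame's descending-id order
pairs = df[['version', 'id']].drop_duplicates()
# The first n pairs of each version, i.e. its n largest ids given the sort
keep = pairs.groupby('version').head(n)
# merge defaults to an inner join on the common columns, here 'version' and 'id'
result = df.merge(keep)
print(result)
```

The merged result gets a new 0-based index; if you need the original row labels, the mask-based answer above keeps them.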
