
Filter in group by in pandas

I have the following dataframe:

import pandas as pd

df = pd.DataFrame(dict(g = [0, 0, 1, 1, 2, 2], x = [0, 1, 1, 2, 2, 3]))

I want to obtain a subset of this dataframe containing only the groups of g for which mean(x) > 0.6. That is, I want a filter_group operation that returns the following dataframe:

>>> filtered_df = filter_group(df)
>>> filtered_df
   g  x
2  1  1
3  1  2
4  2  2
5  2  3

Is there an easy way to do this in pandas? This is similar to the HAVING operation in SQL, but a bit different, since I want to obtain a dataframe with the same schema but fewer rows.


For R users, what I'm trying to do is:

library(dplyr)
df <- tibble(
  g = c(0, 0, 1, 1, 2, 2),
  x = c(0, 1, 1, 2, 2, 3)
)

df %>% 
  group_by(g) %>% 
  filter(mean(x) > 0.6)

Use GroupBy.transform to repeat the aggregated value of each group on every row of that group, then filter the original rows with boolean indexing:

df[df.groupby('g')['x'].transform('mean') > 0.6]
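To make the one-liner above easier to follow, here is a minimal sketch that splits it into its intermediate steps on the question's sample data. The names group_mean and mask are introduced here only for illustration.

```python
import pandas as pd

df = pd.DataFrame(dict(g=[0, 0, 1, 1, 2, 2], x=[0, 1, 1, 2, 2, 3]))

# Per-group mean of x, repeated on every row of the group.
# The result is a Series with the same index as df.
group_mean = df.groupby('g')['x'].transform('mean')

# Boolean mask aligned with df's index: True for rows whose group passes.
mask = group_mean > 0.6

# Boolean indexing keeps the original columns and index labels.
filtered_df = df[mask]
print(filtered_df)
```

Because transform returns a Series aligned with the original index, the mask can be applied directly to df without any merge or reindexing.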

This solution is preferable when the DataFrame is large or has many groups and performance matters:

import numpy as np
import pandas as pd

np.random.seed(2020)

N = 10000
df = pd.DataFrame(dict(g = np.random.randint(1000, size=N), 
                       x = np.random.randint(10000, size=N)))
print(df)

In [89]: %timeit df[df.groupby('g')['x'].transform('mean') > 0.6]
2.01 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [90]: %timeit df.groupby('g').filter(lambda df: df['x'].mean() > 0.6)
145 ms ± 2.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
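The timings above come from IPython's %timeit. Outside IPython, a rough equivalent could look like the sketch below, which uses the standard timeit module and also checks that the two expressions select the same rows. The helper names via_transform and via_filter are hypothetical, and absolute timings will vary by machine and pandas version.

```python
import timeit

import numpy as np
import pandas as pd

np.random.seed(2020)

N = 10000
df = pd.DataFrame(dict(g=np.random.randint(1000, size=N),
                       x=np.random.randint(10000, size=N)))

def via_transform(frame):
    # Vectorized: one aggregation, then boolean indexing.
    return frame[frame.groupby('g')['x'].transform('mean') > 0.6]

def via_filter(frame):
    # Calls the Python lambda once per group.
    return frame.groupby('g').filter(lambda grp: grp['x'].mean() > 0.6)

# Both approaches should keep exactly the same rows.
same_result = via_transform(df).equals(via_filter(df))
print("same result:", same_result)

# Average seconds per call over a few runs.
runs = 3
t_transform = timeit.timeit(lambda: via_transform(df), number=runs) / runs
t_filter = timeit.timeit(lambda: via_filter(df), number=runs) / runs
print(f"transform: {t_transform:.4f} s, filter: {t_filter:.4f} s")
```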

An alternative way to do it is to use the filter method:

df.groupby('g').filter(lambda df: df['x'].mean() > 0.6)

To me this has the following advantages:

  • It generalizes easily if many columns are involved in the filter.
  • It uses the chained pandas paradigm which I am fond of.
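As an illustration of the first point, here is a minimal sketch of a filter condition that involves two columns. The column y, its values, the threshold on it, and the function name keep_group are all hypothetical and not part of the original question.

```python
import pandas as pd

df = pd.DataFrame(dict(
    g=[0, 0, 1, 1, 2, 2],
    x=[0, 1, 1, 2, 2, 3],
    y=[5, 5, 0, 1, 4, 6],  # hypothetical second column
))

def keep_group(group):
    # Keep a group only if both conditions hold for it.
    return bool(group['x'].mean() > 0.6 and group['y'].max() >= 5)

filtered_df = df.groupby('g').filter(keep_group)
print(filtered_df)
```

With these sample values, group 0 fails the condition on x, group 1 fails the condition on y, and only group 2 is kept. The same result can be obtained with transform by combining one mask per column with &, at the cost of a slightly longer expression.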

