
Python equivalent of group_by + mutate using cur_group() (i.e. the value of the grouping variable)

If I have a data frame df and a function f() in R like these:

df = data.frame(
  group=c("cat","fish","horse","cat","fish","horse","cat","horse"),
  x = c(1,4,7,2,5,8,3,9)
)
f <- function(animal,x) {
  nchar(animal) + mean(x)*(x+1)
}

Applying f() to each group and adding its result as a new column is straightforward:

library(dplyr)
mutate(group_by(df,group),result=f(cur_group(),x))

Output:

  group     x result
  <chr> <dbl>  <dbl>
1 cat       1    7  
2 fish      4   26.5
3 horse     7   69  
4 cat       2    9  
5 fish      5   31  
6 horse     8   77  
7 cat       3   11  
8 horse     9   85  

What is the correct way to do the same in Python if d is a pandas.DataFrame?

import numpy as np
import pandas as pd
d = pd.DataFrame({"group":["cat","fish","horse","cat","fish","horse","cat","horse"], "x":[1,4,7,2,5,8,3,9]})

def f(animal,x):
    return [np.mean(x)*(k+1) + len(animal) for k in x]

I know I can get the "correct" values like this:

d.groupby("group").apply(lambda g: f(g.name,g.x))

and I can "explode" that into a single Series with .explode(), but what is the correct way to add the values to the frame as a new column, with each value on its original row?

Expected output (Python):

   group  x  result
0    cat  1     7.0
1   fish  4    26.5
2  horse  7    69.0
3    cat  2     9.0
4   fish  5    31.0
5  horse  8    77.0
6    cat  3    11.0
7  horse  9    85.0
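A minimal sketch that keeps the question's list-returning f(), assuming the same d: the exploded Series is labelled by group ("cat", "cat", "cat", "fish", ...) rather than by row, so it does not line up with d. Wrapping each group's list in a Series that carries that group's original row labels lets pandas align the values by index on assignment. The name pieces is illustrative only.

```python
import numpy as np
import pandas as pd

d = pd.DataFrame({"group": ["cat", "fish", "horse", "cat", "fish", "horse", "cat", "horse"],
                  "x": [1, 4, 7, 2, 5, 8, 3, 9]})

def f(animal, x):
    # The question's f(): returns a plain list, in the order of the group's rows
    return [np.mean(x) * (k + 1) + len(animal) for k in x]

# Wrap each group's list in a Series that keeps that group's original row labels;
# iterating a groupby on a single column name yields (key, sub-frame) pairs
pieces = [pd.Series(f(name, g["x"]), index=g.index) for name, g in d.groupby("group")]

# The pieces are concatenated in group order (cat, fish, horse), but column
# assignment aligns on the index, so every value lands on its original row
d["result"] = pd.concat(pieces)
print(d)
```

The concatenated Series is in group order, yet the assignment reorders it by index, which is why the new column matches the expected output above.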

In pandas, the idiomatic approach follows a different logic.

Instead of putting everything in a function called through apply, keep the operations vectorized. GroupBy.transform broadcasts a per-group scalar (here, the mean of x) back to every row of that group, so the result lines up with d:

g = d.groupby('group')

d['result'] = g['x'].transform('mean').mul(d['x'].add(1))+d['group'].str.len()
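To make the difference between aggregating and transforming concrete, here is a small sketch on the same d (the names per_group and per_row are illustrative): mean() returns one value per group, while transform('mean') returns one value per row of d, which is what lets it combine element-wise with d['x'] and d['group'].

```python
import pandas as pd

d = pd.DataFrame({"group": ["cat", "fish", "horse", "cat", "fish", "horse", "cat", "horse"],
                  "x": [1, 4, 7, 2, 5, 8, 3, 9]})
g = d.groupby("group")

# Aggregation: one value per group, indexed by the group labels
per_group = g["x"].mean()
print(per_group.tolist())  # [2.0, 4.5, 8.0]

# transform: the same means broadcast back onto d's rows, indexed like d
per_row = g["x"].transform("mean")
print(per_row.tolist())  # [2.0, 4.5, 8.0, 2.0, 4.5, 8.0, 2.0, 8.0]
```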

If you really want to use apply, use vectorized code inside the function:

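def f(g):
    # g is the sub-frame for one group; every operation here is vectorized
    return g['x'].mean()*(g['x']+1)+g['group'].str.len()

# group_keys=False keeps the original row index, so the result aligns with d
d['result'] = d.groupby("group", group_keys=False).apply(f)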

output:

   group  x  result
0    cat  1     7.0
1   fish  4    26.5
2  horse  7    69.0
3    cat  2     9.0
4   fish  5    31.0
5  horse  8    77.0
6    cat  3    11.0
7  horse  9    85.0
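A sketch that stays closer to the R call f(cur_group(), x), assuming a vectorized two-argument f() (a rewrite of the question's list-based version, not the asker's original): pandas gives each group passed to a transform callable a name attribute holding its group key (this is mentioned in the GroupBy.transform documentation), so s.name can play the role of cur_group().

```python
import pandas as pd

d = pd.DataFrame({"group": ["cat", "fish", "horse", "cat", "fish", "horse", "cat", "horse"],
                  "x": [1, 4, 7, 2, 5, 8, 3, 9]})

def f(animal, x):
    # Vectorized version of the question's f(): animal is the scalar group key,
    # x is the Series of that group's x values
    return x.mean() * (x + 1) + len(animal)

# transform calls the lambda once per group; s.name is that group's key,
# and the returned values are placed back on the group's original rows
d["result"] = d.groupby("group")["x"].transform(lambda s: f(s.name, s))
print(d)
```

Because transform returns a result indexed like d, no explode or reordering step is needed.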

Using transform directly, in one line:

d['out'] = d.groupby('group')['x'].transform('mean').mul(d['x'].add(1)) + d['group'].str.len()
print(d['out'])

0     7.0
1    26.5
2    69.0
3     9.0
4    31.0
5    77.0
6    11.0
7    85.0
Name: out, dtype: float64
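An equivalent formulation without transform, as a sketch on the same d: compute one mean per group, then look up each row's group mean with Series.map, which matches on the group labels. This can be handy when the per-group values are also needed on their own (the names means and row_means are illustrative).

```python
import pandas as pd

d = pd.DataFrame({"group": ["cat", "fish", "horse", "cat", "fish", "horse", "cat", "horse"],
                  "x": [1, 4, 7, 2, 5, 8, 3, 9]})

# One mean per group, indexed by group label
means = d.groupby("group")["x"].mean()

# Map each row's group label to its group mean; the result is aligned with d
row_means = d["group"].map(means)

d["out"] = row_means * (d["x"] + 1) + d["group"].str.len()
print(d["out"].tolist())  # [7.0, 26.5, 69.0, 9.0, 31.0, 77.0, 11.0, 85.0]
```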

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 