[英]Calculating rolling mean on dataframe column when grouping by IDs in Python
鉴于df:
df = pd.DataFrame([{'a':'id1', 'b':10}, {'a':'id2', 'b':20},
{'a':'id1', 'b':11}, {'a':'id2', 'b':21},
{'a':'id3', 'b':12}, {'a':'id3', 'b':'NaN'},
{'a':'id1', 'b':0}, {'a':'id2', 'b':23},
{'a':'id1', 'b':0}, {'a':'id5', 'b':0}, {'a':'id4', 'b':10}, {'a':'id2', 'b':20},
{'a':'id4', 'b':11}, {'a':'id2', 'b':21},
{'a':'id1', 'b':12}, {'a':'id2', 'b':'NaN'},
{'a':'id3', 'b':0}, {'a':'id4', 'b':23},
{'a':'id1', 'b':0}, {'a':'id2', 'b':0}])
我正在为存储在列“ a”中的给定“ id”计算滚动平均值和列“ b”中先前值的最大值。
我正在使用的当前代码只是在a列中的两种id类型之间交替给出正确的移位,因此,在现实生活中一旦添加了其他id,该方法将无法正常工作:
df['rolling_mean_2'] = (df.assign(b=df.b.shift())
.groupby('a')['b']
.rolling(window=2, min_periods=2)
.mean()
.sort_index(level=1)
.shift()
.values)
df['rolling_mean_last'] = (df.assign(b=df.b.shift())
.groupby('a')['b']
.rolling(window=1, min_periods=1)
.mean()
.sort_index(level=1)
.shift()
.values)
df['rolling_max_4'] = (df.assign(b=df.b.shift())
.groupby('a')['b']
.rolling(window=4, min_periods=4)
.max()
.sort_index(level=1)
.shift()
.values)
输出以下内容:
a b rolling_mean_2 rolling_mean_last rolling_max_4
0 id1 10 NaN NaN NaN
1 id2 20 NaN NaN NaN
2 id1 11 NaN 10.0 NaN
3 id2 21 NaN 20.0 NaN
4 id3 12 10.5 11.0 NaN
5 id3 NaN NaN 21.0 NaN
6 id1 0 16.5 12.0 NaN
7 id2 23 NaN NaN NaN
8 id1 0 5.5 0.0 NaN
9 id5 0 NaN 23.0 NaN
10 id4 10 NaN 0.0 NaN
11 id2 20 NaN 0.0 NaN
12 id4 11 5.0 10.0 11.0
13 id2 21 10.0 20.0 NaN
14 id1 12 10.5 11.0 11.0
15 id2 NaN 22.0 21.0 NaN
16 id3 0 11.5 12.0 12.0
17 id4 23 NaN NaN NaN
18 id1 0 10.0 0.0 NaN
19 id2 0 22.0 23.0 NaN
我的预期输出是:
a b rolling_mean_2 rolling_mean_last rolling_max_4
0 id1 10 NaN NaN NaN
1 id2 20 NaN NaN NaN
2 id1 11 NaN 10.0 NaN
3 id2 21 NaN 20.0 NaN
4 id3 12 NaN NaN NaN
5 id3 NaN NaN 12.0 NaN
6 id1 0 10.5 11.0 NaN
7 id2 23 20.5 21.0 NaN
8 id1 0 10.75 0.0 NaN
9 id5 0 NaN NaN NaN
10 id4 10 NaN NaN NaN
11 id2 20 22 21.0 NaN
12 id4 11 5.0 NaN NaN
13 id2 21 21.5 23.0 23.0
使用:df ['rolling_mean_last'] = df.groupby('a')['b']。apply(lambda x:x.rolling(window = 1,min_periods = 1).mean()。shift())
能够输出预期的df。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.