[英]Apply rolling custom function with pandas
該站點中有一些類似的問題,但我找不到針對我的特定問題的解決方案。
我有一個要使用自定義函數處理的數據框(實際函數有更多的預處理,但要點包含在玩具示例fun
中)。
import statsmodels.api as sm
import numpy as np
import pandas as pd
mtcars=pd.DataFrame(sm.datasets.get_rdataset("mtcars", "datasets", cache=True).data)
def fun(col1, col2, w1=10, w2=2):
return(np.mean(w1 * col1 + w2 * col2))
# This is the behavior I would expect for the full dataset, currently working
mtcars.apply(lambda x: fun(x.cyl, x.mpg), axis=1)
# This was my approach to do the same with a rolling function
mtcars.rolling(3).apply(lambda x: fun(x.cyl, x.mpg))
rolling
版本返回此錯誤:
AttributeError: 'Series' object has no attribute 'cyl'
我想我不完全理解rolling
是如何工作的,因為在我的函數的開頭添加一個打印語句表明fun
的不是獲得完整的數據集而是一個未命名的系列 3。在pandas
中應用這個滾動函數的方法是什么?
以防萬一,我正在運行
>>> pd.__version__
'1.5.2'
看起來這里有一個非常相似的問題,可能與我正在嘗試做的部分重疊。
為了完整起見,下面是我將如何在R
中使用預期輸出執行此操作。
library(dplyr)
fun <- function(col1, col2, w1=10, w2=2){
return(mean(w1*col1 + w2*col2))
}
mtcars %>%
mutate(roll = slider::slide2(.x = cyl,
.y = mpg,
.f = fun,
.before = 1,
.after = 1))
mpg cyl disp hp drat wt qsec vs am gear carb roll
Mazda RX4 21.0 6 160.0 110 3.90 2.620 16.46 0 1 4 4 102
Mazda RX4 Wag 21.0 6 160.0 110 3.90 2.875 17.02 0 1 4 4 96.53333
Datsun 710 22.8 4 108.0 93 3.85 2.320 18.61 1 1 4 1 96.8
Hornet 4 Drive 21.4 6 258.0 110 3.08 3.215 19.44 1 0 3 1 101.9333
Hornet Sportabout 18.7 8 360.0 175 3.15 3.440 17.02 0 0 3 2 105.4667
Valiant 18.1 6 225.0 105 2.76 3.460 20.22 1 0 3 1 107.4
Duster 360 14.3 8 360.0 245 3.21 3.570 15.84 0 0 3 4 97.86667
Merc 240D 24.4 4 146.7 62 3.69 3.190 20.00 1 0 4 2 94.33333
Merc 230 22.8 4 140.8 95 3.92 3.150 22.90 1 0 4 2 90.93333
Merc 280 19.2 6 167.6 123 3.92 3.440 18.30 1 0 4 4 93.2
Merc 280C 17.8 6 167.6 123 3.92 3.440 18.90 1 0 4 4 102.2667
Merc 450SE 16.4 8 275.8 180 3.07 4.070 17.40 0 0 3 3 107.6667
Merc 450SL 17.3 8 275.8 180 3.07 3.730 17.60 0 0 3 3 112.6
Merc 450SLC 15.2 8 275.8 180 3.07 3.780 18.00 0 0 3 3 108.6
Cadillac Fleetwood 10.4 8 472.0 205 2.93 5.250 17.98 0 0 3 4 104
Lincoln Continental 10.4 8 460.0 215 3.00 5.424 17.82 0 0 3 4 103.6667
Chrysler Imperial 14.7 8 440.0 230 3.23 5.345 17.42 0 0 3 4 105
Fiat 128 32.4 4 78.7 66 4.08 2.200 19.47 1 1 4 1 105
Honda Civic 30.4 4 75.7 52 4.93 1.615 18.52 1 1 4 2 104.4667
Toyota Corolla 33.9 4 71.1 65 4.22 1.835 19.90 1 1 4 1 97.2
Toyota Corona 21.5 4 120.1 97 3.70 2.465 20.01 1 0 3 1 100.6
Dodge Challenger 15.5 8 318.0 150 2.76 3.520 16.87 0 0 3 2 101.4667
AMC Javelin 15.2 8 304.0 150 3.15 3.435 17.30 0 0 3 2 109.3333
Camaro Z28 13.3 8 350.0 245 3.73 3.840 15.41 0 0 3 4 111.8
Pontiac Firebird 19.2 8 400.0 175 3.08 3.845 17.05 0 0 3 2 106.5333
Fiat X1-9 27.3 4 79.0 66 4.08 1.935 18.90 1 1 4 1 101.6667
Porsche 914-2 26.0 4 120.3 91 4.43 2.140 16.70 0 1 5 2 95.8
Lotus Europa 30.4 4 95.1 113 3.77 1.513 16.90 1 1 5 2 101.4667
Ford Pantera L 15.8 8 351.0 264 4.22 3.170 14.50 0 1 5 4 103.9333
Ferrari Dino 19.7 6 145.0 175 3.62 2.770 15.50 0 1 5 6 107
Maserati Bora 15.0 8 301.0 335 3.54 3.570 14.60 0 1 5 8 97.4
Volvo 142E 21.4 4 121.0 109 4.11 2.780 18.60 1 1 4 2 96.4
沒有真正優雅的方法來做到這一點。 這是一個建議:
首先安裝numpy_ext
(使用pip install numpy_ext
或pip install numpy_ext --user
)。
其次,您需要單獨計算列並將其連接到原始數據幀:
import statsmodels.api as sm
import pandas as pd
from numpy_ext import rolling_apply as rolling_apply_ext
import numpy as np
mtcars=pd.DataFrame(sm.datasets.get_rdataset("mtcars", "datasets", cache=True).data).reset_index()
def fun(col1, col2, w1=10, w2=2):
return(w1 * col1 + w2 * col2)
Col= pd.DataFrame(rolling_apply_ext(fun, 3, mtcars.cyl.values, mtcars.mpg.values)).rename(columns={2:'rolling'})
mtcars.join(Col["rolling"])
要得到:
index mpg cyl disp hp drat wt qsec vs am \
0 Mazda RX4 21.0 6 160.0 110 3.90 2.620 16.46 0 1
1 Mazda RX4 Wag 21.0 6 160.0 110 3.90 2.875 17.02 0 1
2 Datsun 710 22.8 4 108.0 93 3.85 2.320 18.61 1 1
3 Hornet 4 Drive 21.4 6 258.0 110 3.08 3.215 19.44 1 0
4 Hornet Sportabout 18.7 8 360.0 175 3.15 3.440 17.02 0 0
5 Valiant 18.1 6 225.0 105 2.76 3.460 20.22 1 0
6 Duster 360 14.3 8 360.0 245 3.21 3.570 15.84 0 0
7 Merc 240D 24.4 4 146.7 62 3.69 3.190 20.00 1 0
8 Merc 230 22.8 4 140.8 95 3.92 3.150 22.90 1 0
9 Merc 280 19.2 6 167.6 123 3.92 3.440 18.30 1 0
10 Merc 280C 17.8 6 167.6 123 3.92 3.440 18.90 1 0
11 Merc 450SE 16.4 8 275.8 180 3.07 4.070 17.40 0 0
12 Merc 450SL 17.3 8 275.8 180 3.07 3.730 17.60 0 0
13 Merc 450SLC 15.2 8 275.8 180 3.07 3.780 18.00 0 0
14 Cadillac Fleetwood 10.4 8 472.0 205 2.93 5.250 17.98 0 0
15 Lincoln Continental 10.4 8 460.0 215 3.00 5.424 17.82 0 0
16 Chrysler Imperial 14.7 8 440.0 230 3.23 5.345 17.42 0 0
17 Fiat 128 32.4 4 78.7 66 4.08 2.200 19.47 1 1
18 Honda Civic 30.4 4 75.7 52 4.93 1.615 18.52 1 1
19 Toyota Corolla 33.9 4 71.1 65 4.22 1.835 19.90 1 1
20 Toyota Corona 21.5 4 120.1 97 3.70 2.465 20.01 1 0
21 Dodge Challenger 15.5 8 318.0 150 2.76 3.520 16.87 0 0
22 AMC Javelin 15.2 8 304.0 150 3.15 3.435 17.30 0 0
23 Camaro Z28 13.3 8 350.0 245 3.73 3.840 15.41 0 0
24 Pontiac Firebird 19.2 8 400.0 175 3.08 3.845 17.05 0 0
25 Fiat X1-9 27.3 4 79.0 66 4.08 1.935 18.90 1 1
26 Porsche 914-2 26.0 4 120.3 91 4.43 2.140 16.70 0 1
27 Lotus Europa 30.4 4 95.1 113 3.77 1.513 16.90 1 1
28 Ford Pantera L 15.8 8 351.0 264 4.22 3.170 14.50 0 1
29 Ferrari Dino 19.7 6 145.0 175 3.62 2.770 15.50 0 1
30 Maserati Bora 15.0 8 301.0 335 3.54 3.570 14.60 0 1
31 Volvo 142E 21.4 4 121.0 109 4.11 2.780 18.60 1 1
gear carb rolling
0 4 4 NaN
1 4 4 NaN
2 4 1 85.6
3 3 1 102.8
4 3 2 117.4
5 3 1 96.2
6 3 4 108.6
7 4 2 88.8
8 4 2 85.6
9 4 4 98.4
10 4 4 95.6
11 3 3 112.8
12 3 3 114.6
13 3 3 110.4
14 3 4 100.8
15 3 4 100.8
16 3 4 109.4
17 4 1 104.8
18 4 2 100.8
19 4 1 107.8
20 3 1 83.0
21 3 2 111.0
22 3 2 110.4
23 3 4 106.6
24 3 2 118.4
25 4 1 94.6
26 5 2 92.0
27 5 2 100.8
28 5 4 111.6
29 5 6 99.4
30 5 8 110.0
31 4 2 82.8
經過大量搜索和反對爭論。 我發現了一種受此答案啟發的方法
def fun(series, w1=10, w2=2):
col1 = mtcars.loc[series.index, 'cyl']
col2 = mtcars.loc[series.index, 'mpg']
return(np.mean(w1 * col1 + w2 * col2))
mtcars['roll'] = mtcars.rolling(3, center=True, min_periods=0)['mpg'] \
.apply(fun, raw=False)
mtcars
mpg cyl disp hp ... am gear carb roll
Mazda RX4 21.0 6 160.0 110 ... 1 4 4 102.000000
Mazda RX4 Wag 21.0 6 160.0 110 ... 1 4 4 96.533333
Datsun 710 22.8 4 108.0 93 ... 1 4 1 96.800000
Hornet 4 Drive 21.4 6 258.0 110 ... 0 3 1 101.933333
Hornet Sportabout 18.7 8 360.0 175 ... 0 3 2 105.466667
Valiant 18.1 6 225.0 105 ... 0 3 1 107.400000
Duster 360 14.3 8 360.0 245 ... 0 3 4 97.866667
Merc 240D 24.4 4 146.7 62 ... 0 4 2 94.333333
Merc 230 22.8 4 140.8 95 ... 0 4 2 90.933333
Merc 280 19.2 6 167.6 123 ... 0 4 4 93.200000
Merc 280C 17.8 6 167.6 123 ... 0 4 4 102.266667
Merc 450SE 16.4 8 275.8 180 ... 0 3 3 107.666667
Merc 450SL 17.3 8 275.8 180 ... 0 3 3 112.600000
Merc 450SLC 15.2 8 275.8 180 ... 0 3 3 108.600000
Cadillac Fleetwood 10.4 8 472.0 205 ... 0 3 4 104.000000
Lincoln Continental 10.4 8 460.0 215 ... 0 3 4 103.666667
Chrysler Imperial 14.7 8 440.0 230 ... 0 3 4 105.000000
Fiat 128 32.4 4 78.7 66 ... 1 4 1 105.000000
Honda Civic 30.4 4 75.7 52 ... 1 4 2 104.466667
Toyota Corolla 33.9 4 71.1 65 ... 1 4 1 97.200000
Toyota Corona 21.5 4 120.1 97 ... 0 3 1 100.600000
Dodge Challenger 15.5 8 318.0 150 ... 0 3 2 101.466667
AMC Javelin 15.2 8 304.0 150 ... 0 3 2 109.333333
Camaro Z28 13.3 8 350.0 245 ... 0 3 4 111.800000
Pontiac Firebird 19.2 8 400.0 175 ... 0 3 2 106.533333
Fiat X1-9 27.3 4 79.0 66 ... 1 4 1 101.666667
Porsche 914-2 26.0 4 120.3 91 ... 1 5 2 95.800000
Lotus Europa 30.4 4 95.1 113 ... 1 5 2 101.466667
Ford Pantera L 15.8 8 351.0 264 ... 1 5 4 103.933333
Ferrari Dino 19.7 6 145.0 175 ... 1 5 6 107.000000
Maserati Bora 15.0 8 301.0 335 ... 1 5 8 97.400000
Volvo 142E 21.4 4 121.0 109 ... 1 4 2 96.400000
[32 rows x 12 columns]
要按照我的意願執行此操作,需要做幾件事。 raw=False
如果僅調用.index
( False: passes each row or column as a Series to the function.
),將提供對該系列的fun
訪問。 這是愚蠢且低效的,但它確實有效。 我需要我的窗口center=True
。 我還需要NaN
填充可用信息,所以我設置min_periods=0
。
這種方法有幾點我不喜歡:
fun
范圍之外調用mtcars
有潛在的危險,可能會導致錯誤。.loc
的多重索引不能很好地擴展並且性能可能更差(滾動的次數比需要的多)我不知道有一種方法可以通過將單個函數應用於 pandas 數據框來輕松高效地進行此計算,因為您正在計算多行和多列的值。 一種有效的方法是首先計算要為其計算滾動平均值的列,然后計算滾動平均值:
import statsmodels.api as sm
import pandas as pd
mtcars=pd.DataFrame(sm.datasets.get_rdataset("mtcars", "datasets", cache=True).data)
# Create column
def df_fun(df, col1, col2, w1=10, w2=2):
return w1 * df[col1] + w2 * df[col2]
mtcars['fun_val'] = df_fun(mtcars, 'cyl', 'mpg')
# Calculate rolling average
mtcars['fun_val_r3m'] = mtcars['fun_val'].rolling(3, center=True, min_periods=0).mean()
這給出了正確的答案,並且是有效的,因為每個步驟都應該針對性能進行優化。 我發現像這樣分離行和列計算比你提出的最新方法快大約 10 倍,而且不需要導入 numpy。 如果您不想保留中間計算fun_val
,您可以用滾動平均值fun_val_r3m
覆蓋它。
如果您真的需要在一行中使用apply
來執行此操作,除了您在最新帖子中所做的之外,我不知道還有其他方法。 基於numpy
數組的方法可能會執行得更好,但可讀性較差。
您可以使用以下功能進行滾動應用。 在某些情況下,與 pandas inbuild rolling 相比,它可能會很慢,但具有額外的功能。
函數參數 win_size、min_periods(類似於 pandas,只接受整數輸入)。 另外,參數也用於控制到窗口后,在觀察后將窗口移動到包含。
def roll_apply(df, fn, win_size, min_periods=None, after=None):
if min_periods is None:
min_periods = win_size
else:
assert min_periods >= 1
if after is None:
after = 0
before = win_size - 1 - after
i = np.arange(df.shape[0])
s = np.maximum(i - before, 0)
e = np.minimum(i + after, df.shape[0]) + 1
res = [fn(df.iloc[si:ei]) for si, ei in zip(s, e) if (ei-si) >= min_periods]
idx = df.index[(e-s) >= min_periods]
types = {type(ri) for ri in res}
if len(types) != 1:
return pd.Series(res, index=idx)
t = list(types)[0]
if t == pd.Series:
return pd.DataFrame(res, index=idx)
elif t == pd.DataFrame:
return pd.concat(res, keys=idx)
else:
return pd.Series(res, index=idx)
mtcars['roll'] = roll_apply(mtcars, lambda x: fun(x.cyl, x.mpg), win_size=3, min_periods=1, after=1)
指數 | 英里/加侖 | 圓柱體 | 顯示 | 生命值 | 廢話 | 重量 | 秒 | 對比 | 是 | 齒輪 | 碳水化合物 | 卷 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
馬自達RX4 | 21.0 | 6個 | 160.0 | 110 | 3.9 | 2.62 | 16.46 | 0 | 1個 | 4個 | 4個 | 102.0 |
馬自達 RX4 搖擺 | 21.0 | 6個 | 160.0 | 110 | 3.9 | 2.875 | 17.02 | 0 | 1個 | 4個 | 4個 | 96.53333333333335 |
達特桑710 | 22.8 | 4個 | 108.0 | 93 | 3.85 | 2.32 | 18.61 | 1個 | 1個 | 4個 | 1個 | 96.8 |
大黃蜂 4 驅動器 | 21.4 | 6個 | 258.0 | 110 | 3.08 | 3.215 | 19.44 | 1個 | 0 | 3個 | 1個 | 101.93333333333332 |
大黃蜂 Sportabout | 18.7 | 8個 | 360.0 | 175 | 3.15 | 3.44 | 17.02 | 0 | 0 | 3個 | 2個 | 105.46666666666665 |
英勇 | 18.1 | 6個 | 225.0 | 105 | 2.76 | 3.46 | 20.22 | 1個 | 0 | 3個 | 1個 | 107.40000000000002 |
除塵器 360 | 14.3 | 8個 | 360.0 | 245 | 3.21 | 3.57 | 15.84 | 0 | 0 | 3個 | 4個 | 97.86666666666667 |
商務 240D | 24.4 | 4個 | 146.7 | 62 | 3.69 | 3.19 | 20.0 | 1個 | 0 | 4個 | 2個 | 94.33333333333333 |
商業 230 | 22.8 | 4個 | 140.8 | 95 | 3.92 | 3.15 | 22.9 | 1個 | 0 | 4個 | 2個 | 90.93333333333332 |
商業 280 | 19.2 | 6個 | 167.6 | 123 | 3.92 | 3.44 | 18.3 | 1個 | 0 | 4個 | 4個 | 93.2 |
商務 280C | 17.8 | 6個 | 167.6 | 123 | 3.92 | 3.44 | 18.9 | 1個 | 0 | 4個 | 4個 | 102.26666666666667 |
奔馳 450SE | 16.4 | 8個 | 275.8 | 180 | 3.07 | 4.07 | 17.4 | 0 | 0 | 3個 | 3個 | 107.66666666666667 |
商務 450SL | 17.3 | 8個 | 275.8 | 180 | 3.07 | 3.73 | 17.6 | 0 | 0 | 3個 | 3個 | 112.59999999999998 |
商務 450SLC | 15.2 | 8個 | 275.8 | 180 | 3.07 | 3.78 | 18.0 | 0 | 0 | 3個 | 3個 | 108.60000000000001 |
凱迪拉克弗利特伍德 | 10.4 | 8個 | 472.0 | 205 | 2.93 | 5.25 | 17.98 | 0 | 0 | 3個 | 4個 | 104.0 |
林肯大陸 | 10.4 | 8個 | 460.0 | 215 | 3.0 | 5.424 | 17.82 | 0 | 0 | 3個 | 4個 | 103.66666666666667 |
克萊斯勒帝國 | 14.7 | 8個 | 440.0 | 230 | 3.23 | 5.345 | 17.42 | 0 | 0 | 3個 | 4個 | 105.0 |
菲亞特128 | 32.4 | 4個 | 78.7 | 66 | 4.08 | 2.2 | 19.47 | 1個 | 1個 | 4個 | 1個 | 105.0 |
本田思域 | 30.4 | 4個 | 75.7 | 52 | 4.93 | 1.615 | 18.52 | 1個 | 1個 | 4個 | 2個 | 104.46666666666665 |
豐田花冠 | 33.9 | 4個 | 71.1 | 65 | 4.22 | 1.835 | 19.9 | 1個 | 1個 | 4個 | 1個 | 97.2 |
豐田電暈 | 21.5 | 4個 | 120.1 | 97 | 3.7 | 2.465 | 20.01 | 1個 | 0 | 3個 | 1個 | 100.60000000000001 |
道奇挑戰者 | 15.5 | 8個 | 318.0 | 150 | 2.76 | 3.52 | 16.87 | 0 | 0 | 3個 | 2個 | 101.46666666666665 |
AMC標槍 | 15.2 | 8個 | 304.0 | 150 | 3.15 | 3.435 | 17.3 | 0 | 0 | 3個 | 2個 | 109.33333333333333 |
科邁羅Z28 | 13.3 | 8個 | 350.0 | 245 | 3.73 | 3.84 | 15.41 | 0 | 0 | 3個 | 4個 | 111.8 |
龐蒂亞克火鳥 | 19.2 | 8個 | 400.0 | 175 | 3.08 | 3.845 | 17.05 | 0 | 0 | 3個 | 2個 | 106.53333333333335 |
菲亞特X1-9 | 27.3 | 4個 | 79.0 | 66 | 4.08 | 1.935 | 18.9 | 1個 | 1個 | 4個 | 1個 | 101.66666666666667 |
保時捷 914-2 | 26.0 | 4個 | 120.3 | 91 | 4.43 | 2.14 | 16.7 | 0 | 1個 | 5個 | 2個 | 95.8 |
歐洲蓮花 | 30.4 | 4個 | 95.1 | 113 | 3.77 | 1.513 | 16.9 | 1個 | 1個 | 5個 | 2個 | 101.46666666666665 |
福特 Pantera L | 15.8 | 8個 | 351.0 | 264 | 4.22 | 3.17 | 14.5 | 0 | 1個 | 5個 | 4個 | 103.93333333333332 |
法拉利迪諾 | 19.7 | 6個 | 145.0 | 175 | 3.62 | 2.77 | 15.5 | 0 | 1個 | 5個 | 6個 | 107.0 |
瑪莎拉蒂寶來 | 15.0 | 8個 | 301.0 | 335 | 3.54 | 3.57 | 14.6 | 0 | 1個 | 5個 | 8個 | 97.39999999999999 |
沃爾沃142E | 21.4 | 4個 | 121.0 | 109 | 4.11 | 2.78 | 18.6 | 1個 | 1個 | 4個 | 2個 | 96.4 |
您可以在 roll_apply 函數中傳遞更復雜的函數。 下面是幾個例子
roll_apply(mtcars, lambda d: pd.Series({'A': d.sum().sum(), 'B': d.std().std()}), win_size=3, min_periods=1, after=1) # Simple example to illustrate use case
roll_apply(mtcars, lambda d: d, win_size=3, min_periods=3, after=1) # This will return rolling dataframe
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.