簡體   English   中英

用熊貓應用滾動自定義功能

[英]Apply rolling custom function with pandas

該站點中有一些類似的問題,但我找不到針對我的特定問題的解決方案。

我有一個要使用自定義函數處理的數據框(實際函數有更多的預處理,但要點包含在玩具示例fun中)。

import statsmodels.api as sm
import numpy as np
import pandas as pd
mtcars=pd.DataFrame(sm.datasets.get_rdataset("mtcars", "datasets", cache=True).data)

def fun(col1, col2, w1=10, w2=2):
    return(np.mean(w1 * col1 + w2 * col2))

# This is the behavior I would expect for the full dataset, currently working
mtcars.apply(lambda x: fun(x.cyl, x.mpg), axis=1)

# This was my approach to do the same with a rolling function
mtcars.rolling(3).apply(lambda x: fun(x.cyl, x.mpg))

rolling版本返回此錯誤:

AttributeError: 'Series' object has no attribute 'cyl'

我想我不完全理解rolling是如何工作的,因為在我的函數的開頭添加一個打印語句表明fun的不是獲得完整的數據集而是一個未命名的系列 3。在pandas中應用這個滾動函數的方法是什么?

以防萬一,我正在運行

>>> pd.__version__
'1.5.2'

更新

看起來這里有一個非常相似的問題,可能與我正在嘗試做的部分重疊。

為了完整起見,下面是我將如何在R中使用預期輸出執行此操作。

library(dplyr)

fun <- function(col1, col2, w1=10, w2=2){
  return(mean(w1*col1 + w2*col2))
}

mtcars %>% 
  mutate(roll = slider::slide2(.x = cyl,
                               .y = mpg, 
                               .f = fun, 
                               .before = 1, 
                               .after = 1))


                     mpg cyl  disp  hp drat    wt  qsec vs am gear carb     roll
Mazda RX4           21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4      102
Mazda RX4 Wag       21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4 96.53333
Datsun 710          22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1     96.8
Hornet 4 Drive      21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1 101.9333
Hornet Sportabout   18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2 105.4667
Valiant             18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1    107.4
Duster 360          14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4 97.86667
Merc 240D           24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2 94.33333
Merc 230            22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2 90.93333
Merc 280            19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4     93.2
Merc 280C           17.8   6 167.6 123 3.92 3.440 18.90  1  0    4    4 102.2667
Merc 450SE          16.4   8 275.8 180 3.07 4.070 17.40  0  0    3    3 107.6667
Merc 450SL          17.3   8 275.8 180 3.07 3.730 17.60  0  0    3    3    112.6
Merc 450SLC         15.2   8 275.8 180 3.07 3.780 18.00  0  0    3    3    108.6
Cadillac Fleetwood  10.4   8 472.0 205 2.93 5.250 17.98  0  0    3    4      104
Lincoln Continental 10.4   8 460.0 215 3.00 5.424 17.82  0  0    3    4 103.6667
Chrysler Imperial   14.7   8 440.0 230 3.23 5.345 17.42  0  0    3    4      105
Fiat 128            32.4   4  78.7  66 4.08 2.200 19.47  1  1    4    1      105
Honda Civic         30.4   4  75.7  52 4.93 1.615 18.52  1  1    4    2 104.4667
Toyota Corolla      33.9   4  71.1  65 4.22 1.835 19.90  1  1    4    1     97.2
Toyota Corona       21.5   4 120.1  97 3.70 2.465 20.01  1  0    3    1    100.6
Dodge Challenger    15.5   8 318.0 150 2.76 3.520 16.87  0  0    3    2 101.4667
AMC Javelin         15.2   8 304.0 150 3.15 3.435 17.30  0  0    3    2 109.3333
Camaro Z28          13.3   8 350.0 245 3.73 3.840 15.41  0  0    3    4    111.8
Pontiac Firebird    19.2   8 400.0 175 3.08 3.845 17.05  0  0    3    2 106.5333
Fiat X1-9           27.3   4  79.0  66 4.08 1.935 18.90  1  1    4    1 101.6667
Porsche 914-2       26.0   4 120.3  91 4.43 2.140 16.70  0  1    5    2     95.8
Lotus Europa        30.4   4  95.1 113 3.77 1.513 16.90  1  1    5    2 101.4667
Ford Pantera L      15.8   8 351.0 264 4.22 3.170 14.50  0  1    5    4 103.9333
Ferrari Dino        19.7   6 145.0 175 3.62 2.770 15.50  0  1    5    6      107
Maserati Bora       15.0   8 301.0 335 3.54 3.570 14.60  0  1    5    8     97.4
Volvo 142E          21.4   4 121.0 109 4.11 2.780 18.60  1  1    4    2     96.4

沒有真正優雅的方法來做到這一點。 這是一個建議:

首先安裝numpy_ext (使用pip install numpy_extpip install numpy_ext --user )。

其次,您需要單獨計算列並將其連接到原始數據幀:

import statsmodels.api as sm
import pandas as pd
from numpy_ext import rolling_apply as rolling_apply_ext

import numpy as np

mtcars=pd.DataFrame(sm.datasets.get_rdataset("mtcars", "datasets", cache=True).data).reset_index()
def fun(col1, col2, w1=10, w2=2):
    return(w1 * col1 + w2 * col2)

Col= pd.DataFrame(rolling_apply_ext(fun, 3, mtcars.cyl.values, mtcars.mpg.values)).rename(columns={2:'rolling'})


mtcars.join(Col["rolling"])

要得到:

                  index   mpg  cyl   disp   hp  drat     wt   qsec  vs  am  \
0             Mazda RX4  21.0    6  160.0  110  3.90  2.620  16.46   0   1   
1         Mazda RX4 Wag  21.0    6  160.0  110  3.90  2.875  17.02   0   1   
2            Datsun 710  22.8    4  108.0   93  3.85  2.320  18.61   1   1   
3        Hornet 4 Drive  21.4    6  258.0  110  3.08  3.215  19.44   1   0   
4     Hornet Sportabout  18.7    8  360.0  175  3.15  3.440  17.02   0   0   
5               Valiant  18.1    6  225.0  105  2.76  3.460  20.22   1   0   
6            Duster 360  14.3    8  360.0  245  3.21  3.570  15.84   0   0   
7             Merc 240D  24.4    4  146.7   62  3.69  3.190  20.00   1   0   
8              Merc 230  22.8    4  140.8   95  3.92  3.150  22.90   1   0   
9              Merc 280  19.2    6  167.6  123  3.92  3.440  18.30   1   0   
10            Merc 280C  17.8    6  167.6  123  3.92  3.440  18.90   1   0   
11           Merc 450SE  16.4    8  275.8  180  3.07  4.070  17.40   0   0   
12           Merc 450SL  17.3    8  275.8  180  3.07  3.730  17.60   0   0   
13          Merc 450SLC  15.2    8  275.8  180  3.07  3.780  18.00   0   0   
14   Cadillac Fleetwood  10.4    8  472.0  205  2.93  5.250  17.98   0   0   
15  Lincoln Continental  10.4    8  460.0  215  3.00  5.424  17.82   0   0   
16    Chrysler Imperial  14.7    8  440.0  230  3.23  5.345  17.42   0   0   
17             Fiat 128  32.4    4   78.7   66  4.08  2.200  19.47   1   1   
18          Honda Civic  30.4    4   75.7   52  4.93  1.615  18.52   1   1   
19       Toyota Corolla  33.9    4   71.1   65  4.22  1.835  19.90   1   1   
20        Toyota Corona  21.5    4  120.1   97  3.70  2.465  20.01   1   0   
21     Dodge Challenger  15.5    8  318.0  150  2.76  3.520  16.87   0   0   
22          AMC Javelin  15.2    8  304.0  150  3.15  3.435  17.30   0   0   
23           Camaro Z28  13.3    8  350.0  245  3.73  3.840  15.41   0   0   
24     Pontiac Firebird  19.2    8  400.0  175  3.08  3.845  17.05   0   0   
25            Fiat X1-9  27.3    4   79.0   66  4.08  1.935  18.90   1   1   
26        Porsche 914-2  26.0    4  120.3   91  4.43  2.140  16.70   0   1   
27         Lotus Europa  30.4    4   95.1  113  3.77  1.513  16.90   1   1   
28       Ford Pantera L  15.8    8  351.0  264  4.22  3.170  14.50   0   1   
29         Ferrari Dino  19.7    6  145.0  175  3.62  2.770  15.50   0   1   
30        Maserati Bora  15.0    8  301.0  335  3.54  3.570  14.60   0   1   
31           Volvo 142E  21.4    4  121.0  109  4.11  2.780  18.60   1   1   

    gear  carb  rolling  
0      4     4      NaN  
1      4     4      NaN  
2      4     1     85.6  
3      3     1    102.8  
4      3     2    117.4  
5      3     1     96.2  
6      3     4    108.6  
7      4     2     88.8  
8      4     2     85.6  
9      4     4     98.4  
10     4     4     95.6  
11     3     3    112.8  
12     3     3    114.6  
13     3     3    110.4  
14     3     4    100.8  
15     3     4    100.8  
16     3     4    109.4  
17     4     1    104.8  
18     4     2    100.8  
19     4     1    107.8  
20     3     1     83.0  
21     3     2    111.0  
22     3     2    110.4  
23     3     4    106.6  
24     3     2    118.4  
25     4     1     94.6  
26     5     2     92.0  
27     5     2    100.8  
28     5     4    111.6  
29     5     6     99.4  
30     5     8    110.0  
31     4     2     82.8  

經過大量搜索和反對爭論。 我發現了一種受此答案啟發的方法

def fun(series, w1=10, w2=2):
  col1 = mtcars.loc[series.index, 'cyl']
  col2 = mtcars.loc[series.index, 'mpg']
  return(np.mean(w1 * col1 + w2 * col2))

mtcars['roll'] = mtcars.rolling(3, center=True, min_periods=0)['mpg'] \
                       .apply(fun, raw=False)
mtcars
                      mpg  cyl   disp   hp  ...  am  gear  carb        roll
Mazda RX4            21.0    6  160.0  110  ...   1     4     4  102.000000
Mazda RX4 Wag        21.0    6  160.0  110  ...   1     4     4   96.533333
Datsun 710           22.8    4  108.0   93  ...   1     4     1   96.800000
Hornet 4 Drive       21.4    6  258.0  110  ...   0     3     1  101.933333
Hornet Sportabout    18.7    8  360.0  175  ...   0     3     2  105.466667
Valiant              18.1    6  225.0  105  ...   0     3     1  107.400000
Duster 360           14.3    8  360.0  245  ...   0     3     4   97.866667
Merc 240D            24.4    4  146.7   62  ...   0     4     2   94.333333
Merc 230             22.8    4  140.8   95  ...   0     4     2   90.933333
Merc 280             19.2    6  167.6  123  ...   0     4     4   93.200000
Merc 280C            17.8    6  167.6  123  ...   0     4     4  102.266667
Merc 450SE           16.4    8  275.8  180  ...   0     3     3  107.666667
Merc 450SL           17.3    8  275.8  180  ...   0     3     3  112.600000
Merc 450SLC          15.2    8  275.8  180  ...   0     3     3  108.600000
Cadillac Fleetwood   10.4    8  472.0  205  ...   0     3     4  104.000000
Lincoln Continental  10.4    8  460.0  215  ...   0     3     4  103.666667
Chrysler Imperial    14.7    8  440.0  230  ...   0     3     4  105.000000
Fiat 128             32.4    4   78.7   66  ...   1     4     1  105.000000
Honda Civic          30.4    4   75.7   52  ...   1     4     2  104.466667
Toyota Corolla       33.9    4   71.1   65  ...   1     4     1   97.200000
Toyota Corona        21.5    4  120.1   97  ...   0     3     1  100.600000
Dodge Challenger     15.5    8  318.0  150  ...   0     3     2  101.466667
AMC Javelin          15.2    8  304.0  150  ...   0     3     2  109.333333
Camaro Z28           13.3    8  350.0  245  ...   0     3     4  111.800000
Pontiac Firebird     19.2    8  400.0  175  ...   0     3     2  106.533333
Fiat X1-9            27.3    4   79.0   66  ...   1     4     1  101.666667
Porsche 914-2        26.0    4  120.3   91  ...   1     5     2   95.800000
Lotus Europa         30.4    4   95.1  113  ...   1     5     2  101.466667
Ford Pantera L       15.8    8  351.0  264  ...   1     5     4  103.933333
Ferrari Dino         19.7    6  145.0  175  ...   1     5     6  107.000000
Maserati Bora        15.0    8  301.0  335  ...   1     5     8   97.400000
Volvo 142E           21.4    4  121.0  109  ...   1     4     2   96.400000

[32 rows x 12 columns]


要按照我的意願執行此操作,需要做幾件事。 raw=False如果僅調用.indexFalse: passes each row or column as a Series to the function. ),將提供對該系列的fun訪問。 這是愚蠢且低效的,但它確實有效。 我需要我的窗口center=True 我還需要NaN填充可用信息,所以我設置min_periods=0

這種方法有幾點我不喜歡:

  1. 在我看來,從fun范圍之外調用mtcars有潛在的危險,可能會導致錯誤。
  2. 逐行使用.loc的多重索引不能很好地擴展並且性能可能更差(滾動的次數比需要的多)

我不知道有一種方法可以通過將單個函數應用於 pandas 數據框來輕松高效地進行此計算,因為您正在計算多行和多列的值。 一種有效的方法是首先計算要為其計算滾動平均值的列,然后計算滾動平均值:

import statsmodels.api as sm
import pandas as pd
mtcars=pd.DataFrame(sm.datasets.get_rdataset("mtcars", "datasets", cache=True).data)

# Create column
def df_fun(df, col1, col2, w1=10, w2=2):
    return w1 * df[col1] + w2 * df[col2]
mtcars['fun_val'] = df_fun(mtcars, 'cyl', 'mpg')

# Calculate rolling average
mtcars['fun_val_r3m'] = mtcars['fun_val'].rolling(3, center=True, min_periods=0).mean()

這給出了正確的答案,並且是有效的,因為每個步驟都應該針對性能進行優化。 我發現像這樣分離行和列計算比你提出的最新方法快大約 10 倍,而且不需要導入 numpy。 如果您不想保留中間計算fun_val ,您可以用滾動平均值fun_val_r3m覆蓋它。

在此處輸入圖像描述

如果您真的需要在一行中使用apply來執行此操作,除了您在最新帖子中所做的之外,我不知道還有其他方法。 基於numpy數組的方法可能會執行得更好,但可讀性較差。

您可以使用以下功能進行滾動應用。 在某些情況下,與 pandas inbuild rolling 相比,它可能會很慢,但具有額外的功能。

函數參數 win_size、min_periods(類似於 pandas,只接受整數輸入)。 另外,參數也用於控制到窗口后,在觀察后將窗口移動到包含。

def roll_apply(df, fn, win_size, min_periods=None, after=None):

    if min_periods is None:
        min_periods = win_size
    else:
        assert min_periods >= 1
    
    if after is None:
        after = 0
    
    before = win_size - 1 - after
    i = np.arange(df.shape[0])
    s = np.maximum(i - before, 0)
    e = np.minimum(i + after, df.shape[0]) + 1
    
    res = [fn(df.iloc[si:ei]) for si, ei in zip(s, e) if (ei-si) >= min_periods]
    idx = df.index[(e-s) >= min_periods]

    types = {type(ri) for ri in res}
    if len(types) != 1:
        return pd.Series(res, index=idx)
    
    t = list(types)[0]
    if t == pd.Series:
        return pd.DataFrame(res, index=idx)
    elif t == pd.DataFrame:
        return pd.concat(res, keys=idx)
    else:
        return pd.Series(res, index=idx)
mtcars['roll'] = roll_apply(mtcars, lambda x: fun(x.cyl, x.mpg), win_size=3, min_periods=1, after=1)
指數 英里/加侖 圓柱體 顯示 生命值 廢話 重量 對比 齒輪 碳水化合物
馬自達RX4 21.0 6個 160.0 110 3.9 2.62 16.46 0 1個 4個 4個 102.0
馬自達 RX4 搖擺 21.0 6個 160.0 110 3.9 2.875 17.02 0 1個 4個 4個 96.53333333333335
達特桑710 22.8 4個 108.0 93 3.85 2.32 18.61 1個 1個 4個 1個 96.8
大黃蜂 4 驅動器 21.4 6個 258.0 110 3.08 3.215 19.44 1個 0 3個 1個 101.93333333333332
大黃蜂 Sportabout 18.7 8個 360.0 175 3.15 3.44 17.02 0 0 3個 2個 105.46666666666665
英勇 18.1 6個 225.0 105 2.76 3.46 20.22 1個 0 3個 1個 107.40000000000002
除塵器 360 14.3 8個 360.0 245 3.21 3.57 15.84 0 0 3個 4個 97.86666666666667
商務 240D 24.4 4個 146.7 62 3.69 3.19 20.0 1個 0 4個 2個 94.33333333333333
商業 230 22.8 4個 140.8 95 3.92 3.15 22.9 1個 0 4個 2個 90.93333333333332
商業 280 19.2 6個 167.6 123 3.92 3.44 18.3 1個 0 4個 4個 93.2
商務 280C 17.8 6個 167.6 123 3.92 3.44 18.9 1個 0 4個 4個 102.26666666666667
奔馳 450SE 16.4 8個 275.8 180 3.07 4.07 17.4 0 0 3個 3個 107.66666666666667
商務 450SL 17.3 8個 275.8 180 3.07 3.73 17.6 0 0 3個 3個 112.59999999999998
商務 450SLC 15.2 8個 275.8 180 3.07 3.78 18.0 0 0 3個 3個 108.60000000000001
凱迪拉克弗利特伍德 10.4 8個 472.0 205 2.93 5.25 17.98 0 0 3個 4個 104.0
林肯大陸 10.4 8個 460.0 215 3.0 5.424 17.82 0 0 3個 4個 103.66666666666667
克萊斯勒帝國 14.7 8個 440.0 230 3.23 5.345 17.42 0 0 3個 4個 105.0
菲亞特128 32.4 4個 78.7 66 4.08 2.2 19.47 1個 1個 4個 1個 105.0
本田思域 30.4 4個 75.7 52 4.93 1.615 18.52 1個 1個 4個 2個 104.46666666666665
豐田花冠 33.9 4個 71.1 65 4.22 1.835 19.9 1個 1個 4個 1個 97.2
豐田電暈 21.5 4個 120.1 97 3.7 2.465 20.01 1個 0 3個 1個 100.60000000000001
道奇挑戰者 15.5 8個 318.0 150 2.76 3.52 16.87 0 0 3個 2個 101.46666666666665
AMC標槍 15.2 8個 304.0 150 3.15 3.435 17.3 0 0 3個 2個 109.33333333333333
科邁羅Z28 13.3 8個 350.0 245 3.73 3.84 15.41 0 0 3個 4個 111.8
龐蒂亞克火鳥 19.2 8個 400.0 175 3.08 3.845 17.05 0 0 3個 2個 106.53333333333335
菲亞特X1-9 27.3 4個 79.0 66 4.08 1.935 18.9 1個 1個 4個 1個 101.66666666666667
保時捷 914-2 26.0 4個 120.3 91 4.43 2.14 16.7 0 1個 5個 2個 95.8
歐洲蓮花 30.4 4個 95.1 113 3.77 1.513 16.9 1個 1個 5個 2個 101.46666666666665
福特 Pantera L 15.8 8個 351.0 264 4.22 3.17 14.5 0 1個 5個 4個 103.93333333333332
法拉利迪諾 19.7 6個 145.0 175 3.62 2.77 15.5 0 1個 5個 6個 107.0
瑪莎拉蒂寶來 15.0 8個 301.0 335 3.54 3.57 14.6 0 1個 5個 8個 97.39999999999999
沃爾沃142E 21.4 4個 121.0 109 4.11 2.78 18.6 1個 1個 4個 2個 96.4

您可以在 roll_apply 函數中傳遞更復雜的函數。 下面是幾個例子

roll_apply(mtcars, lambda d: pd.Series({'A': d.sum().sum(), 'B': d.std().std()}), win_size=3, min_periods=1, after=1) # Simple example to illustrate use case

roll_apply(mtcars, lambda d: d, win_size=3, min_periods=3, after=1) # This will return rolling dataframe

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM