繁体   English   中英

加快熊猫滚动窗口

[英]Speed up rolling window in Pandas

我有此代码,可以正常工作,并给我我想要的结果。 它循环显示窗口大小的列表,以为sum_metric_list,min_metric_list和max_metric_list中的每个指标创建滚动聚合。

# create the rolling aggregations for each window
for window in constants.AGGREGATION_WINDOW:
    # get the sum and count sums
    sum_metrics_names_list = [x[6:] + "_1_" + str(window) for x in sum_metrics_list]
    adt_df[sum_metrics_names_list] = adt_df.groupby('athlete_id')[sum_metrics_list].apply(lambda x : x.rolling(center = False, window = window, min_periods = 1).sum())

    # get the min of mins
    min_metrics_names_list = [x[6:] + "_1_" + str(window) for x in min_metrics_list]
    adt_df[min_metrics_names_list] = adt_df.groupby('athlete_id')[min_metrics_list].apply(lambda x : x.rolling(center = False, window = window, min_periods = 1).min())

    # get the max of max
    max_metrics_names_list = [x[6:] + "_1_" + str(window) for x in max_metrics_list]
    adt_df[max_metrics_names_list] = adt_df.groupby('athlete_id')[max_metrics_list].apply(lambda x : x.rolling(center = False, window = window, min_periods = 1).max())

它在小型数据集上运行良好,但是一旦我对具有> 3000个指标和40个窗口的完整数据运行它,它就会变得非常慢。 有什么方法可以优化此代码?

下面的基准测试(和代码)建议您可以使用以下方法节省大量时间

df.groupby(...).rolling() 

代替

df.groupby(...)[col].apply(lambda x: x.rolling(...))

此处节省时间的主要思想是尝试一次(一次调用)将向量化函数(例如sum )应用于最大可能的数组(或DataFrame),而不是许多微小的函数调用。

df.groupby(...).rolling().sum()在每个(分组的)子DataFrame上调用sum 一次调用即可计算所有列的滚动总和。 您可以使用df[sum_metrics_list+[key]].groupby(key).rolling().sum()来计算sum_metrics_list列上的滚动/总和。

相反, df.groupby(...)[col].apply(lambda x: x.rolling(...))在每个(分组的)子DataFrame的单个列上调用sum 由于您有> 3000个指标,因此最终会调用df.groupby(...)[col].rolling().sum() (或minmax )3000次。

当然,这种对调用次数进行计数的伪逻辑只是一种试探法, 可以指导您朝着更快的代码的方向发展。 证明在布丁中:

import collections
import timeit 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def make_df(nrows=100, ncols=3):
    seed = 2018
    np.random.seed(seed)
    df = pd.DataFrame(np.random.randint(10, size=(nrows, ncols)))
    df['athlete_id'] = np.random.randint(10, size=nrows)
    return df

def orig(df, key='athlete_id'):
    columns = list(df.columns.difference([key]))
    result = pd.DataFrame(index=df.index)
    for window in range(2, 4):
        for col in columns:
            colname = 'sum_col{}_winsize{}'.format(col, window)
            result[colname] = df.groupby(key)[col].apply(lambda x: x.rolling(
                center=False, window=window, min_periods=1).sum())
            colname = 'min_col{}_winsize{}'.format(col, window)
            result[colname] = df.groupby(key)[col].apply(lambda x: x.rolling(
                center=False, window=window, min_periods=1).min())
            colname = 'max_col{}_winsize{}'.format(col, window)
            result[colname] = df.groupby(key)[col].apply(lambda x: x.rolling(
                center=False, window=window, min_periods=1).max())
    result = pd.concat([df, result], axis=1)
    return result

def alt(df, key='athlete_id'):
    """
    Call rolling on the whole DataFrame, not each column separately
    """
    columns = list(df.columns.difference([key]))
    result = [df]
    for window in range(2, 4):
        rolled = df.groupby(key, group_keys=False).rolling(
            center=False, window=window, min_periods=1)

        new_df = rolled.sum().drop(key, axis=1)
        new_df.columns = ['sum_col{}_winsize{}'.format(col, window) for col in columns]
        result.append(new_df)

        new_df = rolled.min().drop(key, axis=1)
        new_df.columns = ['min_col{}_winsize{}'.format(col, window) for col in columns]
        result.append(new_df)

        new_df = rolled.max().drop(key, axis=1)
        new_df.columns = ['max_col{}_winsize{}'.format(col, window) for col in columns]
        result.append(new_df)

    df = pd.concat(result, axis=1)
    return df

timing = collections.defaultdict(list)
ncols = [3, 10, 20, 50, 100]
for n in ncols:
    df = make_df(ncols=n)
    timing['orig'].append(timeit.timeit(
        'orig(df)',
        'from __main__ import orig, alt, df',
        number=10))
    timing['alt'].append(timeit.timeit(
        'alt(df)',
        'from __main__ import orig, alt, df',
        number=10))

plt.plot(ncols, timing['orig'], label='using groupby/apply (orig)')
plt.plot(ncols, timing['alt'], label='using groupby/rolling (alternative)')
plt.legend(loc='best')
plt.xlabel('number of columns')
plt.ylabel('seconds')
print(pd.DataFrame(timing, index=pd.Series(ncols, name='ncols')))
plt.show()

在此处输入图片说明 而产生这些timeit基准

            alt       orig
ncols                     
3      0.871695   0.996862
10     0.991617   3.307021
20     1.168522   6.602289
50     1.676441  16.558673
100    2.521121  33.261957

orig相比, alt的速度优势似乎随着列数的增加而增加。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM