
Rolling Sum of a column based on another column in a DataFrame

I have a DataFrame that looks like the one below:

 ID        Date      Amount
10001   2019-07-01    50
10001   2019-05-01    15
10001   2019-06-25    10
10001   2019-05-27    20
10002   2019-06-29    25
10002   2019-07-18    35
10002   2019-07-15    40

From the Amount column, I'm trying to get a 4-week rolling sum based on the Date column. That is, I need one more column (say amount_4wk_rolling) that holds the sum of the Amount column over all rows that go back 4 weeks. So if the date in a row is 2019-07-01, the amount_4wk_rolling value should be the sum of Amount for all rows whose date is between 2019-06-04 (2019-07-01 minus 28 days) and 2019-07-01. The new DataFrame would look something like this.

 ID        Date      Amount  amount_4wk_rolling
10001   2019-07-01    50       60
10001   2019-05-01    15       15
10001   2019-06-25    10       30
10001   2019-05-27    20       35
10002   2019-06-29    25       25
10002   2019-07-18    35       100
10002   2019-07-15    40       65

I have tried using window functions, except they don't let me choose a window based on the value of a particular column.

Edit:
My data is huge: about a TB in size. Ideally, I would like to do this in Spark rather than in pandas.

As suggested, you can use .rolling on Date with a "28d" window.

It seems (from your example values) that you also want the rolling window grouped by ID.

Try this:

import pandas as pd
from io import StringIO

s = """
 ID      Date      Amount   

10001   2019-07-01   50     
10001   2019-05-01   15
10001   2019-06-25   10   
10001   2019-05-27   20
10002   2019-06-29   25
10002   2019-07-18   35
10002   2019-07-15   40
"""

df = pd.read_csv(StringIO(s), sep=r"\s+")
df['Date'] = pd.to_datetime(df['Date'])
# 28-day rolling sum per ID, computed on each group sorted by Date
amounts = df.groupby(["ID"]).apply(lambda g: g.sort_values('Date').rolling('28d', on='Date').sum())
# map the rolling sums back onto the original rows via the (unique) dates
df['amount_4wk_rolling'] = df["Date"].map(amounts.set_index('Date')['Amount'])
print(df)

Output:

      ID       Date  Amount  amount_4wk_rolling
0  10001 2019-07-01      50                60.0
1  10001 2019-05-01      15                15.0
2  10001 2019-06-25      10                10.0
3  10001 2019-05-27      20                35.0
4  10002 2019-06-29      25                25.0
5  10002 2019-07-18      35               100.0
6  10002 2019-07-15      40                65.0

I believe pandas' time-based rolling methods work off the index. So setting:

df.index = df['Date']

and then applying the rolling method with your time range as the window may do the trick.

See also the documentation (specifically the examples at the bottom of the page): https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html

Edit: You can also use the argument on='Date', as pointed out in the comments, so there is no need for re-indexing.
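
A minimal sketch of that approach (using on='Date' and grouping by ID, as in the other pandas answer), with the sample data taken from the question:

import pandas as pd

df = pd.DataFrame({
    'ID': [10001, 10001, 10001, 10001, 10002, 10002, 10002],
    'Date': pd.to_datetime(['2019-07-01', '2019-05-01', '2019-06-25',
                            '2019-05-27', '2019-06-29', '2019-07-18', '2019-07-15']),
    'Amount': [50, 15, 10, 20, 25, 35, 40],
})

# Time-based rolling needs the rows ordered by Date; on='Date' tells rolling
# which column defines the 28-day window, so no re-indexing is required.
rolled = (df.sort_values('Date')
            .groupby('ID')
            .rolling('28d', on='Date')
            .sum())

# The second level of rolled's index is the original row index of df,
# so the rolling sums can be assigned straight back.
df['amount_4wk_rolling'] = rolled.reset_index(level=0, drop=True)['Amount']
print(df)

Because the result is joined back on the row index rather than on the Date values, this also works if two IDs happen to share the same date.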

This can be done with pandas_udf. It looks like you want to group by 'ID', so I used it as the group key.

import numpy as np
from pyspark.sql import SparkSession, Row, functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.appName('test').getOrCreate()
df = spark.createDataFrame([Row(ID=10001, d='2019-07-01', Amount=50),
                            Row(ID=10001, d='2019-05-01', Amount=15),
                            Row(ID=10001, d='2019-06-25', Amount=10),
                            Row(ID=10001, d='2019-05-27', Amount=20),
                            Row(ID=10002, d='2019-06-29', Amount=25),
                            Row(ID=10002, d='2019-07-18', Amount=35),
                            Row(ID=10002, d='2019-07-15', Amount=40)
                           ])
df = df.withColumn('date', F.to_date('d', 'yyyy-MM-dd'))
# lower bound of each row's 4-week window
df = df.withColumn('prev_date', F.date_sub(df['date'], 28))
df.select(["ID", "prev_date", "date", "Amount"]).orderBy('date').show()
# placeholder column so the schema passed to the UDF already contains it
df = df.withColumn('amount_4wk_rolling', F.lit(0.0))

@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def roll_udf(pdf):
    # for each row, sum Amount over the rows of the same group whose date
    # falls within [prev_date, date]
    for index, row in pdf.iterrows():
        d, d_prev = row['date'], row['prev_date']
        pdf.loc[pdf['date'] == d, 'amount_4wk_rolling'] = np.sum(
            pdf.loc[(pdf['date'] <= d) & (pdf['date'] >= d_prev), 'Amount'])
    return pdf

df = df.groupby('ID').apply(roll_udf)
df.select(['ID', 'date', 'prev_date', 'Amount', 'amount_4wk_rolling']).orderBy(['ID', 'date']).show()

The output:

+-----+----------+----------+------+
|   ID| prev_date|      date|Amount|
+-----+----------+----------+------+
|10001|2019-04-03|2019-05-01|    15|
|10001|2019-04-29|2019-05-27|    20|
|10001|2019-05-28|2019-06-25|    10|
|10002|2019-06-01|2019-06-29|    25|
|10001|2019-06-03|2019-07-01|    50|
|10002|2019-06-17|2019-07-15|    40|
|10002|2019-06-20|2019-07-18|    35|
+-----+----------+----------+------+

+-----+----------+----------+------+------------------+
|   ID|      date| prev_date|Amount|amount_4wk_rolling|
+-----+----------+----------+------+------------------+
|10001|2019-05-01|2019-04-03|    15|              15.0|
|10001|2019-05-27|2019-04-29|    20|              35.0|
|10001|2019-06-25|2019-05-28|    10|              10.0|
|10001|2019-07-01|2019-06-03|    50|              60.0|
|10002|2019-06-29|2019-06-01|    25|              25.0|
|10002|2019-07-15|2019-06-17|    40|              65.0|
|10002|2019-07-18|2019-06-20|    35|             100.0|
+-----+----------+----------+------+------------------+

For PySpark, you can just use a window function: sum + rangeBetween.

from pyspark.sql import functions as F, Window

# skip code to initialize Spark session and dataframe

>>> df.show()
+-----+----------+------+
|   ID|      Date|Amount|
+-----+----------+------+
|10001|2019-07-01|    50|
|10001|2019-05-01|    15|
|10001|2019-06-25|    10|
|10001|2019-05-27|    20|
|10002|2019-06-29|    25|
|10002|2019-07-18|    35|
|10002|2019-07-15|    40|
+-----+----------+------+

>>> df.printSchema()
root
 |-- ID: long (nullable = true)
 |-- Date: string (nullable = true)
 |-- Amount: long (nullable = true)

# order by the date as epoch seconds so rangeBetween can express "28 days back" as -28*86400 seconds
win = Window.partitionBy('ID').orderBy(F.to_timestamp('Date').astype('long')).rangeBetween(-28*86400, 0)

df_new = df.withColumn('amount_4wk_rolling', F.sum('Amount').over(win))

>>> df_new.show()
+------+-----+----------+------------------+
|Amount|   ID|      Date|amount_4wk_rolling|
+------+-----+----------+------------------+
|    25|10002|2019-06-29|                25|
|    40|10002|2019-07-15|                65|
|    35|10002|2019-07-18|               100|
|    15|10001|2019-05-01|                15|
|    20|10001|2019-05-27|                35|
|    10|10001|2019-06-25|                10|
|    50|10001|2019-07-01|                60|
+------+-----+----------+------------------+
