
PySpark advanced window function

Here is my dataframe:

import pandas as pd

# `spark` is assumed to be an existing SparkSession (e.g. a PySpark shell or notebook)
FlightDate = [20, 40, 51, 50, 60, 15, 17, 37, 36, 50]
IssuingDate = [10, 15, 44, 45, 55, 10, 2, 30, 32, 24]
Revenue = [100, 50, 40, 70, 60, 40, 30, 100, 200, 100]
Customer = ['a', 'a', 'a', 'a', 'a', 'b', 'b', 'b', 'b', 'b']
df = spark.createDataFrame(
    pd.DataFrame([Customer, FlightDate, IssuingDate, Revenue]).T,
    schema=["Customer", "FlightDate", "IssuingDate", "Revenue"],
)
df.show()

+--------+----------+-----------+-------+
|Customer|FlightDate|IssuingDate|Revenue|
+--------+----------+-----------+-------+
|       a|        20|         10|    100|
|       a|        40|         15|     50|
|       a|        51|         44|     40|
|       a|        50|         45|     70|
|       a|        60|         55|     60|
|       b|        15|         10|     40|
|       b|        27|          2|     30|
|       b|        37|         30|    100|
|       b|        36|         32|    200|
|       b|        50|         24|    100|
+--------+----------+-----------+-------+

For convenience, I used numbers for days.

For each customer, I would like to sum the revenues over all rows whose IssuingDate falls between the studied FlightDate and the studied FlightDate + 10 days.

That is to say:

  • For the first line: I sum all revenues with an IssuingDate between day 20 and day 30, which gives 0 here.
  • For the second line: I sum all revenues with an IssuingDate between day 40 and day 50, that is to say 40 + 70 = 110 (see the small check after this list).
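
To make the rule concrete, here is a small plain-Python check of the computation for customer 'a' (purely illustrative; the variable names are mine, not part of the question's code):

rows_a = list(zip([20, 40, 51, 50, 60], [10, 15, 44, 45, 55], [100, 50, 40, 70, 60]))  # (FlightDate, IssuingDate, Revenue)
for flight, _, _ in rows_a:
    # sum the revenue of every row whose IssuingDate falls in [FlightDate, FlightDate + 10]
    result = sum(rev for _, issue, rev in rows_a if flight <= issue <= flight + 10)
    print(flight, result)
# prints: 20 0, 40 110, 51 60, 50 60, 60 0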

Here is the desired result:

+--------+----------+-----------+-------+------+
|Customer|FlightDate|IssuingDate|Revenue|Result|
+--------+----------+-----------+-------+------+
|       a|        20|         10|    100|     0|
|       a|        40|         15|     50|   110|
|       a|        51|         44|     40|    60|
|       a|        50|         45|     70|    60|
|       a|        60|         55|     60|     0|
|       b|        15|         10|     40|   100|
|       b|        27|          2|     30|   300|
|       b|        37|         30|    100|     0|
|       b|        36|         32|    200|     0|
|       b|        50|         24|    100|     0|
+--------+----------+-----------+-------+------+

I know it will involve some window functions, but this one seems a bit tricky. Thanks!

There is no need for a window function here. It is just a join and an aggregation:

df.alias("df").join(
    df.alias("df_2"),
    on=F.expr(
        "df.Customer = df_2.Customer "
        "and df_2.issuingdate between df.flightdate and df.flightdate+10"
    ), 
    how='left'
).groupBy(
    *('df.{}'.format(c) 
      for c 
      in df.columns)
).agg(
    F.sum(F.coalesce(
        "df_2.revenue", 
        F.lit(0))
    ).alias("result")
).show()

+--------+----------+-----------+-------+------+                                
|Customer|FlightDate|IssuingDate|Revenue|result|
+--------+----------+-----------+-------+------+
|       a|        20|         10|    100|     0|
|       a|        40|         15|     50|   110|
|       a|        50|         45|     70|    60|
|       a|        51|         44|     40|    60|
|       a|        60|         55|     60|     0|
|       b|        15|         10|     40|   100|
|       b|        27|          2|     30|   300|
|       b|        36|         32|    200|     0|
|       b|        37|         30|    100|     0|
|       b|        50|         24|    100|     0|
+--------+----------+-----------+-------+------+
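
If you prefer Column expressions to a SQL string, the same join condition can be written as below (an equivalent sketch, same logic as above):

cond = (
    (F.col("df.Customer") == F.col("df_2.Customer"))
    & F.col("df_2.IssuingDate").between(F.col("df.FlightDate"), F.col("df.FlightDate") + 10)
)
# then: df.alias("df").join(df.alias("df_2"), on=cond, how="left") ... as in the snippet above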

If instead you would like to sum the Revenue of the current row and of the rows whose FlightDate falls within the next 10 days, you can use the code below.

For example:

First line: FlightDate = 20, and you need the revenue between days 20 and 30 (both inclusive), which means Total Revenue = 100.

Second line: FlightDate = 40, and you need the revenue between days 40 and 50 (both inclusive), which means Total Revenue = 50 (for date 40) + 70 (for date 50) = 120.

Third line: FlightDate = 50, and you need the revenue between days 50 and 60 (both inclusive), which means Total Revenue = 70 (for date 50) + 40 (for date 51) + 60 (for date 60) = 170.

from pyspark.sql import Window
from pyspark.sql import functions as F
import pandas as pd

FlightDate = [20, 40, 51, 50, 60, 15, 17, 37, 36, 50]
IssuingDate = [10, 15, 44, 45, 55, 10, 2, 30, 32, 24]
Revenue = [100, 50, 40, 70, 60, 40, 30, 100, 200, 100]
Customer = ['a', 'a', 'a', 'a', 'a', 'b', 'b', 'b', 'b', 'b']
df = spark.createDataFrame(pd.DataFrame([Customer, FlightDate, IssuingDate, Revenue]).T, schema=["Customer", "FlightDate", "IssuingDate", "Revenue"])

# For each row, sum Revenue over the rows of the same customer whose FlightDate
# lies in [current FlightDate, current FlightDate + 10] (rangeBetween is value-based)
windowSpec = Window.partitionBy("Customer").orderBy("FlightDate").rangeBetween(0, 10)
df.withColumn("Sum", F.sum("Revenue").over(windowSpec)).sort("Customer").show()

The result is as shown below:

+--------+----------+-----------+-------+---+
|Customer|FlightDate|IssuingDate|Revenue|Sum|
+--------+----------+-----------+-------+---+
|       a|        20|         10|    100|100|
|       a|        40|         15|     50|120|
|       a|        50|         45|     70|170|
|       a|        51|         44|     40|100|
|       a|        60|         55|     60| 60|
|       b|        15|         10|     40| 70|
|       b|        17|          2|     30| 30|
|       b|        36|         32|    200|300|
|       b|        37|         30|    100|100|
|       b|        50|         24|    100|100|
+--------+----------+-----------+-------+---+
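
Note that with numeric frame boundaries such as rangeBetween(0, 10), Spark requires the orderBy column to be numeric. If FlightDate ends up as a string column in your DataFrame (depending on how the data was created), cast it before applying the window; a minimal sketch:

# only needed if FlightDate was inferred as a string rather than a numeric type
df = df.withColumn("FlightDate", F.col("FlightDate").cast("int"))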
