
Find month-to-date and month-to-go on a PySpark dataframe

I have the following dataframe in Spark (using PySpark):

  • DT_BORD_REF: timestamp column
  • COUNTRY_ALPHA: country Alpha-3 code
  • working_day_flag: whether the date is a working day in that country

I need to add two fields:

  • count of working days from the beginning of the month for that country (month to date)
  • count of working days remaining until the end of that month for that country (month to go)

It seems like an application of a window function, but I can't figure out how to write it.

+-------------------+-------------+----------------+
|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|
+-------------------+-------------+----------------+
|2021-01-01 00:00:00|          FRA|               N|
|2021-01-01 00:00:00|          ITA|               N|
|2021-01-01 00:00:00|          BRA|               N|
|2021-01-02 00:00:00|          BRA|               N|
|2021-01-02 00:00:00|          FRA|               N|
|2021-01-02 00:00:00|          ITA|               N|
|2021-01-03 00:00:00|          ITA|               N|
|2021-01-03 00:00:00|          BRA|               N|
|2021-01-03 00:00:00|          FRA|               N|
|2021-01-04 00:00:00|          BRA|               Y|
|2021-01-04 00:00:00|          FRA|               Y|
|2021-01-04 00:00:00|          ITA|               Y|
|2021-01-05 00:00:00|          FRA|               Y|
|2021-01-05 00:00:00|          BRA|               Y|
|2021-01-05 00:00:00|          ITA|               Y|
|2021-01-06 00:00:00|          ITA|               N|
|2021-01-06 00:00:00|          FRA|               Y|
|2021-01-06 00:00:00|          BRA|               Y|
|2021-01-07 00:00:00|          ITA|               Y|
+-------------------+-------------+----------------+
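For reference, here is one way to reconstruct this sample dataframe (a minimal sketch; the rows are transcribed from the table above, and it assumes an active SparkSession bound to spark):

import pyspark.sql.functions as F

data = [
    ("2021-01-01", "FRA", "N"), ("2021-01-01", "ITA", "N"), ("2021-01-01", "BRA", "N"),
    ("2021-01-02", "BRA", "N"), ("2021-01-02", "FRA", "N"), ("2021-01-02", "ITA", "N"),
    ("2021-01-03", "ITA", "N"), ("2021-01-03", "BRA", "N"), ("2021-01-03", "FRA", "N"),
    ("2021-01-04", "BRA", "Y"), ("2021-01-04", "FRA", "Y"), ("2021-01-04", "ITA", "Y"),
    ("2021-01-05", "FRA", "Y"), ("2021-01-05", "BRA", "Y"), ("2021-01-05", "ITA", "Y"),
    ("2021-01-06", "ITA", "N"), ("2021-01-06", "FRA", "Y"), ("2021-01-06", "BRA", "Y"),
    ("2021-01-07", "ITA", "Y"),
]

# build the dataframe and cast the date strings to timestamps
df = spark.createDataFrame(data, ["DT_BORD_REF", "COUNTRY_ALPHA", "working_day_flag"]) \
          .withColumn("DT_BORD_REF", F.to_timestamp("DT_BORD_REF"))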

Use a running sum over a window function. To limit the window to a single month and country, partition by COUNTRY_ALPHA and DATE_TRUNC('MONTH', DT_BORD_REF). Then, with ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, you get the count of working days up to and including the current date. The same logic gives the remaining working days in the month using ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING.

To count only days with working_day_flag = 'Y', use a conditional sum with CASE/WHEN.

Here's a working example with the sample data you provided in your question:

df.createOrReplaceTempView("df")

sql_query = """
SELECT
  *,
  SUM(CASE
    WHEN BOOLEAN(working_day_flag) THEN 1
    ELSE 0
  END) OVER (
  PARTITION BY COUNTRY_ALPHA, DATE_TRUNC('MONTH', DT_BORD_REF)
  ORDER BY DT_BORD_REF ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
  ) AS month_to_date,

  COALESCE(SUM(CASE
    WHEN BOOLEAN(working_day_flag) THEN 1
    ELSE 0
  END) OVER (
  PARTITION BY COUNTRY_ALPHA, DATE_TRUNC('MONTH', DT_BORD_REF)
  ORDER BY DT_BORD_REF ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING
  ), 0) AS month_to_go

FROM df
""" 

spark.sql(sql_query).show()

#+-------------------+-------------+----------------+-------------+-----------+
#|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
#+-------------------+-------------+----------------+-------------+-----------+
#|2021-01-01 00:00:00|          BRA|               N|            0|          3|
#|2021-01-02 00:00:00|          BRA|               N|            0|          3|
#|2021-01-03 00:00:00|          BRA|               N|            0|          3|
#|2021-01-04 00:00:00|          BRA|               Y|            1|          2|
#|2021-01-05 00:00:00|          BRA|               Y|            2|          1|
#|2021-01-06 00:00:00|          BRA|               Y|            3|          0|
#|2021-01-01 00:00:00|          FRA|               N|            0|          3|
#|2021-01-02 00:00:00|          FRA|               N|            0|          3|
#|2021-01-03 00:00:00|          FRA|               N|            0|          3|
#|2021-01-04 00:00:00|          FRA|               Y|            1|          2|
#|2021-01-05 00:00:00|          FRA|               Y|            2|          1|
#|2021-01-06 00:00:00|          FRA|               Y|            3|          0|
#|2021-01-01 00:00:00|          ITA|               N|            0|          3|
#|2021-01-02 00:00:00|          ITA|               N|            0|          3|
#|2021-01-03 00:00:00|          ITA|               N|            0|          3|
#|2021-01-04 00:00:00|          ITA|               Y|            1|          2|
#|2021-01-05 00:00:00|          ITA|               Y|            2|          1|
#|2021-01-06 00:00:00|          ITA|               N|            2|          1|
#|2021-01-07 00:00:00|          ITA|               Y|            3|          0|
#+-------------------+-------------+----------------+-------------+-----------+

You can do a conditional count using count_if:

df.createOrReplaceTempView('df')

result = spark.sql("""
select *,
    count_if(working_day_flag = 'Y')
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref)
        month_to_date,
    count_if(working_day_flag = 'Y')
        over(partition by country_alpha, trunc(dt_bord_ref, 'month') order by dt_bord_ref
             rows between 1 following and unbounded following)
        month_to_go    
from df
""")

result.show()
+-------------------+-------------+----------------+-------------+-----------+
|        DT_BORD_REF|COUNTRY_ALPHA|working_day_flag|month_to_date|month_to_go|
+-------------------+-------------+----------------+-------------+-----------+
|2021-01-01 00:00:00|          BRA|               N|            0|          3|
|2021-01-02 00:00:00|          BRA|               N|            0|          3|
|2021-01-03 00:00:00|          BRA|               N|            0|          3|
|2021-01-04 00:00:00|          BRA|               Y|            1|          2|
|2021-01-05 00:00:00|          BRA|               Y|            2|          1|
|2021-01-06 00:00:00|          BRA|               Y|            3|          0|
|2021-01-01 00:00:00|          ITA|               N|            0|          3|
|2021-01-02 00:00:00|          ITA|               N|            0|          3|
|2021-01-03 00:00:00|          ITA|               N|            0|          3|
|2021-01-04 00:00:00|          ITA|               Y|            1|          2|
|2021-01-05 00:00:00|          ITA|               Y|            2|          1|
|2021-01-06 00:00:00|          ITA|               N|            2|          1|
|2021-01-07 00:00:00|          ITA|               Y|            3|          0|
|2021-01-01 00:00:00|          FRA|               N|            0|          3|
|2021-01-02 00:00:00|          FRA|               N|            0|          3|
|2021-01-03 00:00:00|          FRA|               N|            0|          3|
|2021-01-04 00:00:00|          FRA|               Y|            1|          2|
|2021-01-05 00:00:00|          FRA|               Y|            2|          1|
|2021-01-06 00:00:00|          FRA|               Y|            3|          0|
+-------------------+-------------+----------------+-------------+-----------+

If you want a similar solution using the PySpark API:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

result = df.withColumn(
    'month_to_date',
    # working days from the start of the month up to the current row;
    # F.count skips nulls, and F.when(...) without otherwise() yields null for 'N' rows
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
              .orderBy('dt_bord_ref')
    )
).withColumn(
    'month_to_go',
    # working days strictly after the current row until the end of the month
    F.count(
        F.when(F.col('working_day_flag') == 'Y', 1)
    ).over(
        Window.partitionBy('country_alpha', F.trunc('dt_bord_ref', 'month'))
              .orderBy('dt_bord_ref')
              .rowsBetween(1, Window.unboundedFollowing)
    )
)
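As a quick check (assuming the df built from the sample data above), ordering by country and date reproduces the output shown earlier:

result.orderBy('COUNTRY_ALPHA', 'DT_BORD_REF').show()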
