

PySpark lag function on one column based on the value in another column


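In the data below, Qdf is the question dataframe and Adf is the answer dataframe. I have added an extra Explanation column for illustration (it is not actually needed in my final data).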

from pyspark.sql.window import Window
import pyspark.sql.functions as func
from pyspark.sql.types import *
from pyspark.sql import SQLContext

ID = ['A' for i in range(0,10)]+ ['B' for i in range(0,10)]
Day = range(1,11)+range(1,11)
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]
Delivered = [0, 0, 2, 3, 1, 0, 10, 0, 0, 13, 0, 0, 3, 1, 0, 6, 0, 0, 6, 3]
Explanation = ["-", "-", "-", "-", "-", "-", "10 (4+6)", "-", "-", "13 (2+6+5)", "-", "-", "-", "-", "-", "6 (2+4)", "-", "-", "6 (1+2+3)", "-"]

QSchema = StructType([StructField("ID", StringType()),StructField("Day", IntegerType()),StructField("Delay", IntegerType()),StructField("Despatched", IntegerType())])
Qdata = map(list, zip(*[ID,Day,Delay,Despatched]))
Qdf = spark.createDataFrame(Qdata,schema=QSchema) 

| ID|Day|Delay|Despatched|
|  A|  1|    2|         2|
|  A|  2|    2|         3|
|  A|  3|    2|         1|
|  A|  4|    3|         4|
|  A|  5|    2|         6|
|  A|  6|    4|         2|
|  A|  7|    3|         6|
|  A|  8|    2|         5|
|  A|  9|    2|         3|
|  A| 10|    2|         6|
|  B|  1|    2|         3|
|  B|  2|    2|         1|
|  B|  3|    3|         2|
|  B|  4|    2|         4|
|  B|  5|    4|         1|
|  B|  6|    3|         2|
|  B|  7|    2|         3|
|  B|  8|    2|         3|
|  B|  9|    2|         6|
|  B| 10|    3|         1|

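The Despatched quantity should be recorded as Delivered once the Delay has elapsed. Ideally, I would like to apply the lag function to the Despatched column with an offset taken from the Delay column. The answer dataset looks like this: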

Adata = map(list, zip(*[ID,Day,Delay,Despatched,Delivered,Explanation]))
ASchema = StructType([StructField("ID", StringType()),StructField("Day", IntegerType()),StructField("Delay", IntegerType()),StructField("Despatched", IntegerType()),StructField("Delivered", IntegerType()),StructField("Explanation", StringType())])
Adf = spark.createDataFrame(Adata,schema=ASchema) 

| ID|Day|Delay|Despatched|Delivered|Explanation|
|  A|  1|    2|         2|        0|          -|
|  A|  2|    2|         3|        0|          -|
|  A|  3|    2|         1|        2|          -|
|  A|  4|    3|         4|        3|          -|
|  A|  5|    2|         6|        1|          -|
|  A|  6|    4|         2|        0|          -|
|  A|  7|    3|         6|       10|   10 (4+6)|
|  A|  8|    2|         5|        0|          -|
|  A|  9|    2|         3|        0|          -|
|  A| 10|    2|         6|       13| 13 (2+6+5)|
|  B|  1|    2|         3|        0|          -|
|  B|  2|    2|         1|        0|          -|
|  B|  3|    3|         2|        3|          -|
|  B|  4|    2|         4|        1|          -|
|  B|  5|    4|         1|        0|          -|
|  B|  6|    3|         2|        6|    6 (2+4)|
|  B|  7|    2|         3|        0|          -|
|  B|  8|    2|         3|        0|          -|
|  B|  9|    2|         6|        6|  6 (1+2+3)|
|  B| 10|    3|         1|        3|          -|
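
The rule behind the Delivered column can be checked outside Spark. The following is a minimal plain-Python sketch, not PySpark code. It assumes that a quantity despatched on day d with delay k arrives on day d + k, and it sums the arrivals per ID and day. The names `arrival_totals` and `computed` are introduced only for this illustration.

```python
from collections import defaultdict

ID = ['A'] * 10 + ['B'] * 10
Day = list(range(1, 11)) + list(range(1, 11))
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]
Delivered = [0, 0, 2, 3, 1, 0, 10, 0, 0, 13, 0, 0, 3, 1, 0, 6, 0, 0, 6, 3]

# Assumption: goods despatched on `day` with delay `delay` arrive on day + delay.
arrival_totals = defaultdict(int)
for id_, day, delay, qty in zip(ID, Day, Delay, Despatched):
    arrival_totals[(id_, day + delay)] += qty

# Look up the total arriving on each (ID, Day); days with no arrivals get 0.
computed = [arrival_totals.get((id_, day), 0) for id_, day in zip(ID, Day)]

print(computed == Delivered)  # True
```

Under that assumption, several despatches can land on the same day, which is why day 7 for A shows 10 (4+6).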

I have tried the code below to get a constant lag of 2:





How can I overcome this? I am using PySpark version 2.3.1 and Python version 2.7.13.


from pyspark.sql.window import Window
import pyspark.sql.functions as F
import pyspark.sql.types as T 

ID = ['A' for i in range(0,10)]+ ['B' for i in range(0,10)]
# I had to modify this line because I am working with Python 3
Day = list(range(1,11))+list(range(1,11))
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]
Delivered = [0, 0, 2, 3, 1, 0, 10, 0, 0, 13, 0, 0, 3, 1, 0, 6, 0, 0, 6, 3]
Explanation = ["-", "-", "-", "-", "-", "-", "10 (4+6)", "-", "-", "13 (2+6+5)", "-", "-", "-", "-", "-", "6 (2+4)", "-", "-", "6 (1+2+3)", "-"]

QSchema = T.StructType([T.StructField("ID", T.StringType()),T.StructField("Day", T.IntegerType()),T.StructField("Delay", T.IntegerType()),T.StructField("Despatched", T.IntegerType())])
Qdata = map(list, zip(*[ID,Day,Delay,Despatched]))
Qdf = spark.createDataFrame(Qdata,schema=QSchema) 
# Up to here it is basically your code

# First add an empty Delivered_lag column to Qdf.
# This lets every iteration of the loop below use the same expression.
Qdf = Qdf.withColumn('Delivered_lag', F.lit(None).cast(T.IntegerType()))

# Loop over the distinct values of Delay and apply lag() with each value as a constant offset.
# otherwise() keeps the values calculated in earlier iterations.
w = Window.partitionBy('ID').orderBy('Day')
for row in Qdf.select('Delay').distinct().collect():
    Qdf = Qdf.withColumn(
        'Delivered_lag',
        F.when(F.col('Delay') == row['Delay'], F.lag('Despatched', row['Delay']).over(w))
         .otherwise(F.col('Delivered_lag')))



| ID|Day|Delay|Despatched|Delivered_lag|
|  B|  1|    2|         3|         null|
|  B|  2|    2|         1|         null|
|  B|  3|    3|         2|         null| 
|  B|  4|    2|         4|            1| 
|  B|  5|    4|         1|            3| 
|  B|  6|    3|         2|            2| 
|  B|  7|    2|         3|            1| 
|  B|  8|    2|         3|            2| 
|  B|  9|    2|         6|            3| 
|  B| 10|    3|         1|            3| 
|  A|  1|    2|         2|         null| 
|  A|  2|    2|         3|         null| 
|  A|  3|    2|         1|            2| 
|  A|  4|    3|         4|            2| 
|  A|  5|    2|         6|            1| 
|  A|  6|    4|         2|            3| 
|  A|  7|    3|         6|            4| 
|  A|  8|    2|         5|            2| 
|  A|  9|    2|         3|            6| 
|  A| 10|    2|         6|            5| 
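
To see what the loop computes, here is a plain-Python sketch (not PySpark) that mimics the result for one ID. It assumes the days are consecutive, so that lagging by Delay rows is the same as looking back Delay days. The helper `lag_by_own_delay` is hypothetical and written only for this illustration.

```python
def lag_by_own_delay(days, delays, despatched):
    """For each row, return the Despatched value `delay` days earlier, or None."""
    by_day = dict(zip(days, despatched))
    # A look-back before the first day has no match, so .get() returns None.
    return [by_day.get(day - delay) for day, delay in zip(days, delays)]

# Rows for ID 'A' from the question, already ordered by Day.
days = list(range(1, 11))
delays = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2]
despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6]

result = lag_by_own_delay(days, delays, despatched)
print(result)
# [None, None, 2, 2, 1, 3, 4, 2, 6, 5]
```

These values match the Delivered_lag column for A above. Note that each row looks back by its own Delay, so the result is not the same as the Delivered column in the question: for A on day 7 the lag gives 4, whereas the question expects 10 (4+6).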


