![](/img/trans.png)
[英]Pyspark - DataFrame not updated when applying functions in a loop
[英]PySpark: applying varying window sizes to a dataframe in pyspark
我有一個火花 dataframe,看起來像下面這樣。
日期 | ID | 窗口大小 | 數量 |
---|---|---|---|
2020 年 1 月 1 日 | 1 | 2 | 1 |
2020 年 2 月 1 日 | 1 | 2 | 2 |
2020 年 3 月 1 日 | 1 | 2 | 3 |
2020 年 4 月 1 日 | 1 | 2 | 4 |
2020 年 1 月 1 日 | 2 | 3 | 1 |
2020 年 2 月 1 日 | 2 | 3 | 2 |
2020 年 3 月 1 日 | 2 | 3 | 3 |
2020 年 4 月 1 日 | 2 | 3 | 4 |
我正在嘗試將大小為 window_size 的滾動 window 應用於 dataframe 中的每個 ID 並獲得滾動總和。 基本上我正在計算一個滾動總和( pd.groupby.rolling(window=n).sum()
in pandas),其中 window 大小(n)可以按組更改。
預期 output
日期 | ID | 窗口大小 | 數量 | 滾動總和 |
---|---|---|---|---|
2020 年 1 月 1 日 | 1 | 2 | 1 | null |
2020 年 2 月 1 日 | 1 | 2 | 2 | 3 |
2020 年 3 月 1 日 | 1 | 2 | 3 | 5 |
2020 年 4 月 1 日 | 1 | 2 | 4 | 7 |
2020 年 1 月 1 日 | 2 | 3 | 1 | null |
2020 年 2 月 1 日 | 2 | 3 | 2 | null |
2020 年 3 月 1 日 | 2 | 3 | 3 | 6 |
2020 年 4 月 1 日 | 2 | 3 | 4 | 9 |
我正在努力尋找一個在大型 dataframe(+- 350M 行)上有效且足夠快的解決方案。
我試過的
我在下面的線程中嘗試了解決方案:
這個想法是首先使用sf.collect_list
然后正確切片ArrayType
列。
import pyspark.sql.types as st
import pyspark.sql.function as sf
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
但是,這會產生以下錯誤:
TypeError:列不可迭代
我也嘗試過使用sf.expr
,如下所示
window = Window.partitionBy('id').orderBy(params['date'])
output = (
sdf
.withColumn("qty_list", sf.collect_list('qty').over(window))
.withColumn("count", sf.count('qty').over(window))
.withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
.otherwise(sf.expr("slice('window_size', 'count', 'window_size')")))
).show()
產生:
數據類型不匹配:參數 1 需要數組類型,但 ''qty_list'' 是字符串類型。 第 1 行第 0 行;
我嘗試將qty_list
列手動轉換為ArrayType(IntegerType())
,結果相同。
我嘗試使用 UDF,但在 1.5 小時左右后出現幾個 memory 錯誤失敗。
問題
閱讀 spark 文檔向我表明我應該能夠將列傳sf.slice()
,我做錯了嗎? TypeError
來自哪里?
有沒有更好的方法來實現我想要的而不使用sf.collect_list()
和/或sf.slice()
?
如果所有其他方法都失敗了,那么使用 udf 執行此操作的最佳方法是什么? 我嘗試了同一個 udf 的不同版本,並試圖確保 udf 是 spark 必須執行的最后一個操作,但都失敗了。
關於你得到的錯誤:
slice
(除非您有 Spark 3.x)。 但是當您嘗試在 SQL 表達式中使用它時,您已經得到了它。expr
中引用的列名。 它應該是slice(qty_list, count, window_size)
否則 Spark 將它們視為字符串,因此會出現錯誤消息。 也就是說,你幾乎明白了,你需要更改切片的表達式,然后使用aggregate
function 對結果數組的值求和。 其他部分實際上與您的嘗試相同:
from pyspark.sql import Window
import pyspark.sql.functions as F
w = Window.partitionBy('id').orderBy('date')
output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
.withColumn("rn", F.row_number().over(w)) \
.withColumn(
"qty_list",
F.when(
F.col('rn') < F.col('window_size'),
None
).otherwise(F.expr("slice(qty_list, rn-window_size+1, window_size)"))
).withColumn(
"rolling_sum",
F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
).drop("qty_list", "rn")
output.show()
#+----------+---+-----------+---+-----------+
#| date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020| 1| 2| 1| null|
#|02/01/2020| 1| 2| 2| 3|
#|03/01/2020| 1| 2| 3| 5|
#|04/01/2020| 1| 2| 4| 7|
#|01/01/2020| 2| 3| 1| null|
#|02/01/2020| 2| 3| 2| null|
#|03/01/2020| 2| 3| 3| 6|
#|04/01/2020| 2| 3| 4| 9|
#+----------+---+-----------+---+-----------+
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.