[英]pyspark window function calculation issue with avg method
我有一个输入 dataframe 如下:
partner_id|month_id|value1 |value2
1001 | 01 |10 |20
1002 | 01 |20 |30
1003 | 01 |30 |40
1001 | 02 |40 |50
1002 | 02 |50 |60
1003 | 02 |60 |70
1001 | 03 |70 |80
1002 | 03 |80 |90
1003 | 03 |90 |100
使用下面的代码,我创建了两个新列,它们使用 window function 进行平均:
rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn("value1_1", F.avg("value1").over(rnum))
df = df.withColumn("value1_2", F.avg("value2").over(rnum))
Output:
partner_id|month_id|value1 |value2|value1_1|value2_2
1001 | 01 |10 |20 |10 |20
1002 | 01 |20 |30 |20 |30
1003 | 01 |30 |40 |30 |40
1001 | 02 |40 |50 |25 |35
1002 | 02 |50 |60 |35 |45
1003 | 02 |60 |70 |45 |55
1001 | 03 |70 |80 |40 |50
1002 | 03 |80 |90 |50 |60
1003 | 03 |90 |100 |60 |70
使用 pyspark Window function,累积平均值在 value1 和 value2 列上表现良好。 但是,如果我们在下面的输入中错过了一个月的数据,那么下个月的平均计算应该基于月份。 而不是正常的平均值。 例如,如果输入如下所示(缺少 02 月数据)
partner_id|month_id|value1 |value2
1001 | 01 |10 |20
1002 | 01 |20 |30
1003 | 01 |30 |40
1001 | 03 |70 |80
1002 | 03 |80 |90
1003 | 03 |90 |100
然后对第三个月记录的平均计算发生如下:例如:(70 + 10)/2 但是,如果缺少某些月份值,正确的平均方法是什么?
如果您使用的是火花 2.4+。 您可以使用序列 function 和数组函数。 此解决方案受此链接启发
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w= Window().partitionBy("partner_id")
df1 =df.withColumn("month_seq", F.sequence(F.min("month_id").over(w), F.max("month_id").over(w), F.lit(1)))\
.groupBy("partner_id").agg(F.collect_list("month_id").alias("month_id"), F.collect_list("value1").alias("value1"), F.collect_list("value2").alias("value2")
,F.first("month_seq").alias("month_seq")).withColumn("month_seq", F.array_except("month_seq","month_id"))\
.withColumn("month_id",F.flatten(F.array("month_id","month_seq"))).drop("month_seq")\
.withColumn("zip", F.explode(F.arrays_zip("month_id","value1", "value2"))) \
.select("partner_id", "zip.month_id", F.when(F.col("zip.value1").isNull() , \
F.lit(0)).otherwise(F.col("zip.value1")).alias("value1"),
F.when(F.col("zip.value2").isNull(), F.lit(0)).otherwise(F.col("zip.value2")
).alias("value2")).orderBy("month_id")
rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df2 = df1.withColumn("value1_1", F.avg("value1").over(rnum)).withColumn("value1_2", F.avg("value2").over(rnum))
df2.show()
# +----------+--------+------+------+------------------+------------------+
# |partner_id|month_id|value1|value2| value1_1| value1_2|
# +----------+--------+------+------+------------------+------------------+
# | 1002| 1| 10| 20| 10.0| 20.0|
# | 1002| 2| 0| 0| 5.0| 10.0|
# | 1002| 3| 80| 90| 30.0|36.666666666666664|
# | 1001| 1| 10| 10| 10.0| 10.0|
# | 1001| 2| 0| 0| 5.0| 5.0|
# | 1001| 3| 70| 80|26.666666666666668| 30.0|
# | 1003| 1| 30| 40| 30.0| 40.0|
# | 1003| 2| 0| 0| 15.0| 20.0|
# | 1003| 3| 90| 100| 40.0|46.666666666666664|
# +----------+--------+------+------+------------------+------------------+
Spark 不够聪明,无法理解缺少一个月,因为它甚至不知道一个月可能是什么。
如果您希望“缺失”月份包含在平均计算中,则需要生成缺失数据。
只需使用 dataframe ["month_id", "defaultValue"] 执行完全外连接,其中 month_id 是从 1 到 12 的值,defaultValue = 0。
另一种解决方案不是执行平均值,而是执行值的总和,然后除以月份数。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.