繁体   English   中英

pyspark window function 使用平均法计算问题

[英]pyspark window function calculation issue with avg method

我有一个输入 dataframe 如下:

partner_id|month_id|value1 |value2
1001      |  01    |10     |20    
1002      |  01    |20     |30    
1003      |  01    |30     |40
1001      |  02    |40     |50    
1002      |  02    |50     |60    
1003      |  02    |60     |70
1001      |  03    |70     |80    
1002      |  03    |80     |90    
1003      |  03    |90     |100

使用下面的代码,我创建了两个新列,它们使用 window function 进行平均:

rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn("value1_1", F.avg("value1").over(rnum))
df = df.withColumn("value1_2", F.avg("value2").over(rnum))

Output:

partner_id|month_id|value1 |value2|value1_1|value2_2
1001      |  01    |10     |20    |10      |20
1002      |  01    |20     |30    |20      |30
1003      |  01    |30     |40    |30      |40
1001      |  02    |40     |50    |25      |35
1002      |  02    |50     |60    |35      |45
1003      |  02    |60     |70    |45      |55
1001      |  03    |70     |80    |40      |50
1002      |  03    |80     |90    |50      |60
1003      |  03    |90     |100   |60      |70

使用 pyspark Window function,累积平均值在 value1 和 value2 列上表现良好。 但是,如果我们在下面的输入中错过了一个月的数据,那么下个月的平均计算应该基于月份。 而不是正常的平均值。 例如,如果输入如下所示(缺少 02 月数据)

partner_id|month_id|value1 |value2
1001      |  01    |10     |20    
1002      |  01    |20     |30    
1003      |  01    |30     |40
1001      |  03    |70     |80    
1002      |  03    |80     |90    
1003      |  03    |90     |100

然后对第三个月记录的平均计算发生如下:例如:(70 + 10)/2 但是,如果缺少某些月份值,正确的平均方法是什么?

如果您使用的是火花 2.4+。 您可以使用序列 function 和数组函数。 此解决方案受此链接启发

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    w= Window().partitionBy("partner_id")

    df1 =df.withColumn("month_seq", F.sequence(F.min("month_id").over(w), F.max("month_id").over(w), F.lit(1)))\
        .groupBy("partner_id").agg(F.collect_list("month_id").alias("month_id"), F.collect_list("value1").alias("value1"), F.collect_list("value2").alias("value2")
         ,F.first("month_seq").alias("month_seq")).withColumn("month_seq", F.array_except("month_seq","month_id"))\
        .withColumn("month_id",F.flatten(F.array("month_id","month_seq"))).drop("month_seq")\
        .withColumn("zip", F.explode(F.arrays_zip("month_id","value1", "value2"))) \
        .select("partner_id", "zip.month_id", F.when(F.col("zip.value1").isNull() , \
                                          F.lit(0)).otherwise(F.col("zip.value1")).alias("value1"),
                                          F.when(F.col("zip.value2").isNull(), F.lit(0)).otherwise(F.col("zip.value2")
                                                                                         ).alias("value2")).orderBy("month_id")

    rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))

    df2 = df1.withColumn("value1_1", F.avg("value1").over(rnum)).withColumn("value1_2", F.avg("value2").over(rnum))

    df2.show()

    # +----------+--------+------+------+------------------+------------------+
    # |partner_id|month_id|value1|value2|          value1_1|          value1_2|
    # +----------+--------+------+------+------------------+------------------+
    # |      1002|       1|    10|    20|              10.0|              20.0|
    # |      1002|       2|     0|     0|               5.0|              10.0|
    # |      1002|       3|    80|    90|              30.0|36.666666666666664|
    # |      1001|       1|    10|    10|              10.0|              10.0|
    # |      1001|       2|     0|     0|               5.0|               5.0|
    # |      1001|       3|    70|    80|26.666666666666668|              30.0|
    # |      1003|       1|    30|    40|              30.0|              40.0|
    # |      1003|       2|     0|     0|              15.0|              20.0|
    # |      1003|       3|    90|   100|              40.0|46.666666666666664|
    # +----------+--------+------+------+------------------+------------------+

Spark 不够聪明,无法理解缺少一个月,因为它甚至不知道一个月可能是什么。

如果您希望“缺失”月份包含在平均计算中,则需要生成缺失数据。

只需使用 dataframe ["month_id", "defaultValue"] 执行完全外连接,其中 month_id 是从 1 到 12 的值,defaultValue = 0。


另一种解决方案不是执行平均值,而是执行值的总和,然后除以月份数。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM