繁体   English   中英

PySpark 数据帧的条件聚合

[英]Conditional aggregate for a PySpark dataframe

我正在尝试对 PySpark 数据框执行条件聚合。

我尝试了 sum/avg,它似乎工作正常,但不知何故计数给出了错误的结果。

from pyspark.sql import functions as F

df = spark.createDataFrame([('a', '1', 2502, 332), 
                              ('b', '1', 2328, 56),
                              ('a', '1', 21, 78),
                              ('b', '2', 234, 23),
                              ('b', '2', 785, 12)
                             ],
                             ['x','id', 'y','z'])
df.show()

+---+---+----+---+
|  x| id|   y|  z|
+---+---+----+---+
|  a|  1|2502|332|
|  b|  1|2328| 56|
|  a|  1|  21| 78|
|  b|  2| 234| 23|
|  b|  2| 785| 12|
+---+---+----+---+
df_new = df.groupBy("id").agg(
                        F.avg(F.when((F.col("x") == 'a'), F.col('y'))
                               .otherwise(0)).alias('col1'),

                        F.count(F.when((F.col("x") == 'b'), F.col('y'))
                                 .otherwise(0)).alias('col2'),

                        F.sum(F.when((F.col("x") == 'b'), F.col('y'))
                               .otherwise(0)).alias('col3')
    )
df_new.show()

+---+-----+----+----+
| id| col1|col2|col3|
+---+-----+----+----+
|  1|841.0|   3|2328|
|  2|  0.0|   2|1019|
+---+-----+----+----+

理想情况下,计数应按行给出12 ,如预期的结果是:

+---+-----+----+----+
| id| col1|col2|col3|
+---+-----+----+----+
|  1|841.0|   1|2328|
|  2|  0.0|   2|1019|
+---+-----+----+----+

因为第二行是唯一符合id='1'x='b' 但出于某种原因,它显示为 3。

您需要从count删除.otherwise 因为0也会增加计数。

import pyspark.sql.functions as F

df_new = df.groupBy("id").agg(
                        F.avg(F.when((F.col("x") == 'a'), F.col('y')).otherwise(0)).alias('col1'),

                        F.count(F.when((F.col("x") == 'b'), F.col('y'))).alias('col2'),

                        F.sum(F.when((F.col("x") == 'b'), F.col('y')).otherwise(0)).alias('col3')
    )

df_new.show()

+---+-----+----+----+
| id| col1|col2|col3|
+---+-----+----+----+
|  1|841.0|   1|2328|
|  2|  0.0|   2|1019|
+---+-----+----+----+

或使用如下所示的sum

import pyspark.sql.functions as F

df_new = df.groupBy("id").agg(
                        F.avg(F.when((F.col("x") == 'a'), F.col('y')).otherwise(0)).alias('col1'),

                        F.sum(F.when((F.col("x") == 'b'), 1).otherwise(0)).alias('col2'),

                        F.sum(F.when((F.col("x") == 'b'), F.col('y')).otherwise(0)).alias('col3')
    )

df_new.show()
+---+-----+----+----+
| id| col1|col2|col3|
+---+-----+----+----+
|  1|841.0|   1|2328|
|  2|  0.0|   2|1019|
+---+-----+----+----+

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM