
PySpark 3.3.0 - Aggregate sum with condition to avoid self join

Given the following dataframe structure:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

I'm trying to aggregate the data to get a failure rate. My resulting dataframe would look like this.

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|  getBooks|       0.25|
|getClasses|          1|
+----------+-----------+

I'm currently able to do this by creating a second dataframe that filters out the success rows, joining the two dataframes back together, and then creating a new column that divides the sum of the failure count (for that endpoint) by the sum of the total count.

I'm trying to find a way to avoid creating a separate dataframe and then re-joining it, as that seems expensive and unnecessary. Is there a way to sum columns conditionally? I've been playing around with the syntax but am getting stuck.

If I could do something like this:

df.groupBy("endPoint").sum("count").when(outcome = "failure")

that would be ideal but I'm having trouble with this and wonder if I'm missing something fundamental here.
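Conceptually, summing conditionally is a single pass over the rows: each row's count is always added to its endpoint's total, and added to the failure total only when outcome is "failure", so no filtered copy and no join are needed. A minimal plain-Python sketch of that idea (not Spark code; the tuples are just the sample table above):

```python
from collections import defaultdict

# Sample rows from the question: (endPoint, count, outcome)
rows = [
    ("getBooks", 3, "success"),
    ("getBooks", 1, "failure"),
    ("getClasses", 0, "success"),
    ("getClasses", 4, "failure"),
]

failures = defaultdict(int)  # conditional sum: only rows where outcome == "failure"
totals = defaultdict(int)    # unconditional sum: every row

# One pass over the data; no filtered copy, no join.
for end_point, count, outcome in rows:
    totals[end_point] += count
    if outcome == "failure":
        failures[end_point] += count

failure_rate = {ep: failures[ep] / totals[ep] for ep in totals}
print(failure_rate)  # {'getBooks': 0.25, 'getClasses': 1.0}
```

In Spark, the same shape is a groupBy() followed by an agg() whose sum() wraps a when() condition.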

You can use a when() inside the sum aggregate: rows that don't match the condition become null, and sum() ignores nulls.

import pyspark.sql.functions as func

# data_sdf is assumed to be the question's df with endPoint renamed to end_point
data_sdf = df.withColumnRenamed('endPoint', 'end_point')

data_sdf. \
    groupBy('end_point'). \
    agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
        func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count')
        ). \
    withColumn('failure_rate',
               func.col('failure_count') / (func.col('failure_count') + func.col('success_count'))
               ). \
    show()

# +----------+-------------+-------------+------------+
# | end_point|failure_count|success_count|failure_rate|
# +----------+-------------+-------------+------------+
# |getClasses|            4|            0|         1.0|
# |  getBooks|            1|            3|        0.25|
# +----------+-------------+-------------+------------+
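One caveat with when() without otherwise(): non-matching rows become null, Spark's sum() skips nulls but returns null when every input in a group is null, and null in arithmetic yields null. So an endpoint with only failure rows would get a null success_count and therefore a null failure_rate. The sample data has both outcomes for every endpoint, so the table above is unaffected; adding .otherwise(0), e.g. func.when(func.col('outcome') == 'success', func.col('count')).otherwise(0), avoids the problem. The plain-Python model below mimics those null rules with None; sql_when and sql_sum are hypothetical helpers, not PySpark functions, and the failure-only endpoint is invented for illustration:

```python
# Plain-Python model of Spark SQL null handling; None plays the role of null.
# sql_when and sql_sum are hypothetical helpers for illustration, not PySpark APIs.

def sql_when(condition, value, otherwise=None):
    """CASE WHEN condition THEN value ELSE otherwise END (ELSE defaults to NULL)."""
    return value if condition else otherwise


def sql_sum(values):
    """SUM skips NULLs, but returns NULL when every input is NULL."""
    non_null = [v for v in values if v is not None]
    return sum(non_null) if non_null else None


def failure_rate(rows, otherwise=None):
    """rows are (count, outcome) pairs for a single endpoint."""
    failure = sql_sum([sql_when(o == "failure", c, otherwise) for c, o in rows])
    success = sql_sum([sql_when(o == "success", c, otherwise) for c, o in rows])
    if failure is None or success is None:
        return None  # NULL in arithmetic yields NULL
    total = failure + success
    return None if total == 0 else failure / total  # x / 0 is NULL with ANSI mode off


# Hypothetical endpoint that only ever failed.
only_failures = [(2, "failure"), (5, "failure")]

print(failure_rate(only_failures))               # None: like when() without otherwise()
print(failure_rate(only_failures, otherwise=0))  # 1.0: like when(...).otherwise(0)
```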

This is easily achieved by using Spark windows:

import pyspark.sql.functions as F
from pyspark.sql import Window

# One window per endpoint; with no orderBy the frame covers the whole partition
w = Window.partitionBy("endPoint")

(
    df.withColumn("total", F.sum("count").over(w))
    .withColumn("failureRate", F.col("count") / F.col("total"))
    .where(F.col("outcome") == "failure")  # filter while outcome is still selected
    .select("endPoint", "failureRate")
    .show()
)
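A window partitioned by endPoint with no orderBy sums count over the whole partition and attaches that total to every row, instead of collapsing the rows the way groupBy does; the failure rows are then kept and divided by the total. A plain-Python sketch of those semantics (not Spark code; rows is the sample table):

```python
from collections import defaultdict

rows = [
    {"endPoint": "getBooks", "count": 3, "outcome": "success"},
    {"endPoint": "getBooks", "count": 1, "outcome": "failure"},
    {"endPoint": "getClasses", "count": 0, "outcome": "success"},
    {"endPoint": "getClasses", "count": 4, "outcome": "failure"},
]

# Window.partitionBy("endPoint"): one total per partition ...
partition_total = defaultdict(int)
for r in rows:
    partition_total[r["endPoint"]] += r["count"]

# ... attached to every row of that partition (no rows are collapsed).
with_total = [dict(r, total=partition_total[r["endPoint"]]) for r in rows]

# Keep only the failure rows and divide by the partition total.
result = [
    (r["endPoint"], r["count"] / r["total"])
    for r in with_total
    if r["outcome"] == "failure"
]
print(result)  # [('getBooks', 0.25), ('getClasses', 1.0)]
```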

You can pivot your dataframe into a wide version in which each outcome value becomes its own column containing the sum of the count column. From that dataframe, you can calculate the failure rate:

import pyspark.sql.functions as F

# init example table (assumes an active SparkSession named spark)
df = spark.createDataFrame(
    [
        ("getBooks", 3, "success"),
        ("getBooks", 1, "failure"),
        ("getClasses", 0, "success"),
        ("getClasses", 4, "failure"),
    ],
    ["endPoint", "count", "outcome"],
)
df.show()

# one column per listed outcome, each holding sum(count) for that endPoint
df_pivot = df.groupBy("endPoint").pivot("outcome", ["success", "failure"]).sum("count")
df_pivot.show()

# total per endpoint, then the failure share of that total
df_total = df_pivot.withColumn("total", F.col("success") + F.col("failure"))
df_total.show()
df_failure_rate = df_total.select("endPoint", (F.col("failure") / F.col("total")).alias("failureRate"))
df_failure_rate.show()

Output:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

+----------+-------+-------+
|  endPoint|success|failure|
+----------+-------+-------+
|getClasses|      0|      4|
|  getBooks|      3|      1|
+----------+-------+-------+

+----------+-------+-------+-----+
|  endPoint|success|failure|total|
+----------+-------+-------+-----+
|getClasses|      0|      4|    4|
|  getBooks|      3|      1|    4|
+----------+-------+-------+-----+

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|getClasses|        1.0|
|  getBooks|       0.25|
+----------+-----------+
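pivot("outcome", ["success", "failure"]) turns each listed value into its own column holding sum(count) for that (endPoint, outcome) pair; passing the values explicitly spares Spark an extra pass to discover the distinct values, and values not in the list are dropped. A group with no rows for a listed value gets null in that column, which would make total and failureRate null too. A plain-Python sketch of that reshaping (not Spark code; None stands in for null):

```python
rows = [
    ("getBooks", 3, "success"),
    ("getBooks", 1, "failure"),
    ("getClasses", 0, "success"),
    ("getClasses", 4, "failure"),
]
pivot_values = ["success", "failure"]  # the explicit list passed to pivot()

# Wide table: one dict per endPoint, one key per pivot value (None = no rows seen).
wide = {}
for end_point, count, outcome in rows:
    if outcome not in pivot_values:
        continue  # values missing from the list get no column
    row = wide.setdefault(end_point, {v: None for v in pivot_values})
    row[outcome] = (row[outcome] or 0) + count

failure_rate = {
    ep: cols["failure"] / (cols["success"] + cols["failure"])
    for ep, cols in wide.items()
}
print(wide)          # {'getBooks': {'success': 3, 'failure': 1}, 'getClasses': {'success': 0, 'failure': 4}}
print(failure_rate)  # {'getBooks': 0.25, 'getClasses': 1.0}
```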

Here's the most efficient version (well, it ties with the other solution I provided), which builds on the answer from @samkart.

So now it just comes down to which one you find easier to understand.

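import pyspark.sql.functions as func

# Assumes endPoint was renamed to end_point (the plan below shows endPoint AS end_point).
# NOTE: keeping only 'failure' rows means the 'success' sum runs over nulls only, so
# success_count, and therefore failure_rate, come back null; drop the filter for real rates.
df. \
    filter(func.col("outcome") == "failure"). \
    groupBy('end_point'). \
    agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
        func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count')
        ). \
    withColumn('failure_rate',
               func.col('failure_count') / (func.col('failure_count') + func.col('success_count'))
               ). \
    explain()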

Explain

== Physical Plan ==
*(2) HashAggregate(keys=[end_point#1039], functions=[sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
+- Exchange hashpartitioning(end_point#1039, 200)
   +- *(1) HashAggregate(keys=[end_point#1039], functions=[partial_sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), partial_sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
      +- *(1) Project [endPoint#78 AS end_point#1039, count#79L, outcome#80]
         +- *(1) Filter (isnotnull(outcome#80) && (outcome#80 = failure))
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

This was run with the sample data from fskj.

But generally this gives you the idea of what I want to do. It's worth comparing the explain() output of this plan against the other ones to see which is more efficient (really, that's the only way to determine which is better).

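import pyspark.sql.functions as F

# Sketch of the idea only: F.sum('outcome') sums a string column (Spark casts it to
# double, which gives null), and F.count counts rows rather than summing 'count', so
# failure_rate is not yet the failure rate. The 'success' rows would also still be
# needed to compute each endpoint's total.
result = df. \
    groupBy('endPoint', 'outcome'). \
    agg(F.sum('outcome').alias("sum"), F.count("endPoint").alias("count")). \
    where(F.col('outcome') != "success"). \
    withColumn('failure_rate',
               F.col('sum') / F.col('count')). \
    select('endPoint', 'failure_rate')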

Explain for my solution (I believe this to be the most efficient, as predicate pushdown removes the "success" data early, so it operates on less data. It also does not require a sort.):

== Physical Plan ==
*(2) HashAggregate(keys=[endPoint#78, outcome#80], functions=[sum(cast(outcome#80 as double)), count(endPoint#78)])
+- Exchange hashpartitioning(endPoint#78, outcome#80, 200)
   +- *(1) HashAggregate(keys=[endPoint#78, outcome#80], functions=[partial_sum(cast(outcome#80 as double)), partial_count(endPoint#78)])
      +- *(1) Project [endPoint#78, outcome#80]
         +- *(1) Filter (isnotnull(outcome#80) && NOT (outcome#80 = success))
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

Window explain:

== Physical Plan ==
*(2) Project [endPoint#78, (cast(count#79L as double) / cast(total#1016L as double)) AS failureRate#1021]
+- *(2) Filter (isnotnull(outcome#80) && (outcome#80 = failure))
   +- Window [sum(count#79L) windowspecdefinition(endPoint#78, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS total#1016L], [endPoint#78]
      +- *(1) Sort [endPoint#78 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(endPoint#78, 200)
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]
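In this window plan, Exchange hashpartitioning sits directly on top of the scan, so every input row is shuffled and then sorted within its partition before the Window operator runs. The groupBy/agg plans instead run a partial_sum HashAggregate before their Exchange, so at most one row per grouping key per input partition is shuffled. A rough plain-Python count of shuffled rows under those two shapes, using made-up sizes (the getUsers endpoint and the partition sizes are invented; this is not a measurement):

```python
def rows_shuffled_window(partitions):
    # Window: no partial aggregation, so every input row crosses the Exchange.
    return sum(len(p) for p in partitions)


def rows_shuffled_two_phase(partitions):
    # partial_sum first: one row per distinct key per input partition is shuffled.
    return sum(len({end_point for end_point, *_ in p}) for p in partitions)


# Made-up input: 4 partitions, each with 10_000 rows spread over 3 endpoints.
endpoints = ["getBooks", "getClasses", "getUsers"]
partitions = [
    [(endpoints[i % 3], 1, "success") for i in range(10_000)]
    for _ in range(4)
]

print(rows_shuffled_window(partitions))     # 40000
print(rows_shuffled_two_phase(partitions))  # 12
```

With only four sample rows the difference is negligible; it matters as the number of rows per key grows.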

Pivot explain: (not as efficient)

df_failure_rate.explain()
== Physical Plan ==
HashAggregate(keys=[endPoint#78], functions=[pivotfirst(outcome#80, sum(`count`)#89L, success, failure, 0, 0)])
+- Exchange hashpartitioning(endPoint#78, 200)
   +- HashAggregate(keys=[endPoint#78], functions=[partial_pivotfirst(outcome#80, sum(`count`)#89L, success, failure, 0, 0)])
      +- *(2) HashAggregate(keys=[endPoint#78, outcome#80], functions=[sum(count#79L)])
         +- Exchange hashpartitioning(endPoint#78, outcome#80, 200)
            +- *(1) HashAggregate(keys=[endPoint#78, outcome#80], functions=[partial_sum(count#79L)])
               +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

when() answer explain (very efficient if the data set is small, as it doesn't filter out the 'success' rows):

== Physical Plan ==
*(2) HashAggregate(keys=[end_point#1039], functions=[sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
+- Exchange hashpartitioning(end_point#1039, 200)
   +- *(1) HashAggregate(keys=[end_point#1039], functions=[partial_sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), partial_sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
      +- *(1) Project [endPoint#78 AS end_point#1039, count#79L, outcome#80]
         +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]
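The plan above is Spark's two-phase aggregation: each input partition first computes partial_sum per key (the *(1) HashAggregate), only those partial rows are shuffled by Exchange hashpartitioning, and the *(2) HashAggregate merges them into the final sum. A plain-Python sketch of that flow, assuming two arbitrary input partitions of the sample rows (not Spark code):

```python
from collections import Counter

# Two hypothetical input partitions of (endPoint, count, outcome) rows.
partitions = [
    [("getBooks", 3, "success"), ("getClasses", 4, "failure")],
    [("getBooks", 1, "failure"), ("getClasses", 0, "success")],
]


def partial_sums(rows):
    """Stage 1 (before the shuffle): conditional partial sums per key in one partition."""
    failure, total = Counter(), Counter()
    for end_point, count, outcome in rows:
        total[end_point] += count
        if outcome == "failure":
            failure[end_point] += count
    return failure, total


# Stage 2 (after the shuffle): merge the per-partition partials per key.
final_failure, final_total = Counter(), Counter()
for failure, total in map(partial_sums, partitions):
    final_failure.update(failure)  # Counter.update adds counts
    final_total.update(total)

rates = {ep: final_failure[ep] / final_total[ep] for ep in final_total}
print(rates)  # {'getBooks': 0.25, 'getClasses': 1.0}
```

Because sums can be merged in any order, the partials computed per partition give the same result as one pass over all rows.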
