
PySpark 3.3.0 - Aggregate sum with condition to avoid self join

Given the following dataframe structure:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

I am trying to aggregate the data to get the failure rate. The resulting dataframe would look like this:

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|  getBooks|       0.25|
|getClasses|          1|
+----------+-----------+

I am currently able to do this by creating a second dataframe to filter out the success rows, then joining the two dataframes back together and creating a new column that divides the sum of the failure counts (for that endpoint) by the total sum of counts.
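
A minimal sketch of that self-join approach (hypothetical code illustrating the description above, not the asker's actual implementation) might look like this:

import pyspark.sql.functions as F

# sum of the failure counts per endpoint
failures = (
    df.where(F.col("outcome") == "failure")
    .groupBy("endPoint")
    .agg(F.sum("count").alias("failureCount"))
)

# total counts per endpoint, joined back to the failure sums
totals = df.groupBy("endPoint").agg(F.sum("count").alias("totalCount"))

failure_rate = (
    failures.join(totals, "endPoint")
    .withColumn("failureRate", F.col("failureCount") / F.col("totalCount"))
    .select("endPoint", "failureRate")
)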

I am trying to find a way to avoid creating a separate dataframe and then having to join them back together, as that seems both expensive and unnecessary. Is there a way to sum a column conditionally? I have been playing around with the syntax but am stuck.

If I could do something like this:

df.groupBy("endPoint").sum("count").when(outcome = "failure")

That would be ideal, but I am running into trouble and wondering if I am missing something fundamental here.

You can use when() within the sum aggregation:

import pyspark.sql.functions as func

data_sdf. \
    groupBy('end_point'). \
    agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
        func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count')
        ). \
    withColumn('failure_rate', 
               func.col('failure_count') / (func.col('failure_count') + func.col('success_count'))
               ). \
    show()

# +----------+-------------+-------------+------------+
# | end_point|failure_count|success_count|failure_rate|
# +----------+-------------+-------------+------------+
# |getClasses|            4|            0|         1.0|
# |  getBooks|            1|            3|        0.25|
# +----------+-------------+-------------+------------+
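
If only the rate itself is needed, the conditional sum can also be divided by the plain sum of count inside the same aggregation. A minimal variant of the snippet above (an added sketch reusing the same data_sdf and func alias, not part of the original answer):

data_sdf. \
    groupBy('end_point'). \
    agg((func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))) /
         func.sum('count')).alias('failure_rate')). \
    show()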

This can be achieved easily by using Spark windows:

import pyspark.sql.functions as F
from pyspark.sql import SparkSession, Window

w = Window.partitionBy("endPoint")

(
    df.withColumn("total", F.sum("count").over(w))
    .withColumn("failureRate", F.col("count") / F.col("total"))
    .select("endPoint", "failureRate")
    .where(F.col("outcome") == "failure")
    .show()
)

You can perform a pivot on the outcome column to get a wide version of the dataframe, where the outcome strings are turned into separate columns containing the sum of the count column. From that dataframe, you can calculate the failure rate:

import pyspark.sql.functions as F
# init example table
df = spark.createDataFrame(
    [
        ("getBooks", 3, "success"),
        ("getBooks", 1, "failure"),
        ("getClasses", 0, "success"),
        ("getClasses", 4, "failure"),
    ],
    ["endPoint", "count", "outcome"],
)
df.show()
df_pivot = df.groupBy("endPoint").pivot("outcome", ["success", "failure"]).sum("count")
df_pivot.show()
df_total = df_pivot.withColumn("total", F.col("success") + F.col("failure"))
df_total.show()
df_failure_rate = df_total.select("endPoint", (F.col("failure") / F.col("total")).alias("failureRate"))
df_failure_rate.show()

Output:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

+----------+-------+-------+
|  endPoint|success|failure|
+----------+-------+-------+
|getClasses|      0|      4|
|  getBooks|      3|      1|
+----------+-------+-------+

+----------+-------+-------+-----+
|  endPoint|success|failure|total|
+----------+-------+-------+-----+
|getClasses|      0|      4|    4|
|  getBooks|      3|      1|    4|
+----------+-------+-------+-----+

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|getClasses|        1.0|
|  getBooks|       0.25|
+----------+-----------+

This is the most efficient (well, it is tied with the other solution I have provided), and it builds on @samkart's answer.

So know that it really just comes down to whichever one you find easier to understand.

import pyspark.sql.functions as func

df. \
    filter(func.col("outcome") == "failure"). \
    groupBy('end_point'). \
    agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
        func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count')
        ). \
    withColumn('failure_rate',
               func.col('failure_count') / (func.col('failure_count') + func.col('success_count'))
               ). \
    explain()

Explain:

== Physical Plan ==
*(2) HashAggregate(keys=[end_point#1039], functions=[sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
+- Exchange hashpartitioning(end_point#1039, 200)
   +- *(1) HashAggregate(keys=[end_point#1039], functions=[partial_sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), partial_sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
      +- *(1) Project [endPoint#78 AS end_point#1039, count#79L, outcome#80]
         +- *(1) Filter (isnotnull(outcome#80) && (outcome#80 = failure))
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

Run using the example data from fskj.

But generally this gives you an idea of what I am trying to do. It may be worth looking at the explain() of this plan versus the others to see which is more efficient. (Really, that is the only way to determine which is better.)

import pyspark.sql.functions as F
result = df. \
    groupBy('endPoint','outcome'). \
    agg( F.sum('outcome').alias("sum"),F.count("endPoint").alias("count") ). \
    where( F.col('outcome') != "success" ). \
    withColumn('failure_rate', 
               F.col('sum') / F.col('count') ). \
    select('endPoint','failure_rate')

Explain for my solution: (I think this is the most efficient, as it uses predicate pushdown to remove the 'success' data early and therefore operates on less data. It also does not require a sort.)

== Physical Plan ==
*(2) HashAggregate(keys=[endPoint#78, outcome#80], functions=[sum(cast(outcome#80 as double)), count(endPoint#78)])
+- Exchange hashpartitioning(endPoint#78, outcome#80, 200)
   +- *(1) HashAggregate(keys=[endPoint#78, outcome#80], functions=[partial_sum(cast(outcome#80 as double)), partial_count(endPoint#78)])
      +- *(1) Project [endPoint#78, outcome#80]
         +- *(1) Filter (isnotnull(outcome#80) && NOT (outcome#80 = success))
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

Window explain:

== Physical Plan ==
*(2) Project [endPoint#78, (cast(count#79L as double) / cast(total#1016L as double)) AS failureRate#1021]
+- *(2) Filter (isnotnull(outcome#80) && (outcome#80 = failure))
   +- Window [sum(count#79L) windowspecdefinition(endPoint#78, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS total#1016L], [endPoint#78]
      +- *(1) Sort [endPoint#78 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(endPoint#78, 200)
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

Pivot explain: (not as efficient)

df_failure_rate.explain()
== Physical Plan ==
HashAggregate(keys=[endPoint#78], functions=[pivotfirst(outcome#80, sum(`count`)#89L, success, failure, 0, 0)])
+- Exchange hashpartitioning(endPoint#78, 200)
   +- HashAggregate(keys=[endPoint#78], functions=[partial_pivotfirst(outcome#80, sum(`count`)#89L, success, failure, 0, 0)])
      +- *(2) HashAggregate(keys=[endPoint#78, outcome#80], functions=[sum(count#79L)])
         +- Exchange hashpartitioning(endPoint#78, outcome#80, 200)
            +- *(1) HashAggregate(keys=[endPoint#78, outcome#80], functions=[partial_sum(count#79L)])
               +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

when answer explain: (very efficient if the dataset is small, as it does not filter out the 'success' results.)

== Physical Plan ==
*(2) HashAggregate(keys=[end_point#1039], functions=[sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
+- Exchange hashpartitioning(end_point#1039, 200)
   +- *(1) HashAggregate(keys=[end_point#1039], functions=[partial_sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), partial_sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
      +- *(1) Project [endPoint#78 AS end_point#1039, count#79L, outcome#80]
         +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]
