簡體   English   中英

檢索pyspark中a DataFrame的每組中的top n

[英]Retrieve top n in each group of a DataFrame in pyspark

pyspark中有一個DataFrame,數據如下:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

我期望的是在每個組中返回 2 條具有相同 user_id 的記錄,這些記錄需要具有最高分。 因此,結果應如下所示:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

我真的是 pyspark 的新手,誰能給我一個代碼片段或這個問題的相關文檔的門戶? 太謝謝了!

我相信您需要使用窗口函數根據user_idscore獲得每行的排名,然后過濾結果以僅保留前兩個值。

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) 
  .filter(col('rank') <= 2) 
  .show() 
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

一般來說,官方編程指南是開始學習 Spark 的好地方。

數據

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

如果在獲得秩相等時使用row_number而不是rank ,則 Top-n 更准確:

val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

Note limit(20).toPandas()技巧代替了 Jupyter 筆記本的show()以獲得更好的格式。

我知道這個問題是針對pyspark提出的,我正在Scala尋找類似的答案,即

在Scala中檢索DataFrame的每組中的前n個值

這是@mtoto 答案的scala版本。

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col

val window = Window.partitionBy("user_id").orderBy('score desc)
val rankByScore = rank().over(window)
df1.select('*, rankByScore as 'rank).filter(col("rank") <= 2).show() 
# you can change the value 2 to any number you want. Here 2 represents the top 2 values

可以在此處找到更多示例。

這是另一個沒有窗口函數的解決方案,可以從 pySpark DataFrame 獲取前 N 條記錄。

# Import Libraries
from pyspark.sql.functions import col

# Sample Data
rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])

# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)

# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)

# Display DataFrame
sorted_df.show()

輸出

+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2|    2|
| user_2| object_2|    2|
| user_1| object_1|    3|
| user_2| object_1|    5|
| user_2| object_2|    6|
+-------+---------+-----+

如果您對 Spark 中的更多窗口函數感興趣,可以參考我的博客之一: https : //medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86

使用 Python 3 和 Spark 2.4

from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

top_n_df = get_topN(your_dataframe, [group_by_columns],[order_by_columns], 1) 

要使用ROW_NUMBER()函數在 PYSPARK SQLquery 中查找第 N 個最大值:

SELECT * FROM (
    SELECT e.*, 
    ROW_NUMBER() OVER (ORDER BY col_name DESC) rn 
    FROM Employee e
)
WHERE rn = N

N 是該列所需的第 n 個最高值

輸出:

[Stage 2:>               (0 + 1) / 1]++++++++++++++++
+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

查詢將返回 N 個最高值

如何在同一個 spark SQL 查詢中獲得前 N 和后 N 條記錄?

添加到 moto 的答案中,您還可以考慮使用row_number ,因為rank 有時會為相同的行提供相同的值,這可能會為每個 group 提供不同數量的值 見下文:

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col

win = Window.partitionBy(sdf["user_id"]).orderBy(sdf["score"].desc())

sdf.select("*", row_number().over(win).alias('row_n'))\
  .filter(col('row_n') <= 2)\
  .show()

#+-------+---------+-----+-----+
#|user_id|object_id|score|row_n|
#+-------+---------+-----+-----+
#| user_1| object_1|    3|    1|
#| user_1| object_2|    2|    2|
#| user_2| object_2|    6|    1|
#| user_2| object_1|    5|    2|
#+-------+---------+-----+-----+

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM