Retrieve top n in each group of a DataFrame in pyspark
There's a DataFrame in pyspark with data as below:
user_id object_id score
user_1 object_1 3
user_1 object_1 1
user_1 object_2 2
user_2 object_1 5
user_2 object_2 2
user_2 object_2 6
What I expect is to return 2 records in each group with the same user_id, which need to have the highest score. Consequently, the result should look like the following:
user_id object_id score
user_1 object_1 3
user_1 object_2 2
user_2 object_2 6
user_2 object_1 5
I'm really new to pyspark; could anyone give me a code snippet or a pointer to the related documentation for this problem? Great thanks!
I believe you need to use window functions to attain the rank of each row based on user_id and score, and subsequently filter your results to only keep the first two values.
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1| 3| 1|
#| user_1| object_2| 2| 2|
#| user_2| object_2| 6| 1|
#| user_2| object_1| 5| 2|
#+-------+---------+-----+----+
In general, the official programming guide is a good place to start learning Spark.
rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
Top-n is more accurate if you use row_number instead of rank, because rank can assign the same value to tied rows:
from pyspark.sql.functions import row_number

n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
.where(col('row_number') <= n) \
.limit(20) \
.toPandas()
Note the limit(20).toPandas() trick instead of show() for nicer formatting in Jupyter notebooks.
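To see the difference in behaviour, here is a minimal sketch using hypothetical tied data (not the question's sample): rank() gives both tied rows the value 1, so a "rank <= 1" filter keeps two rows, while row_number() numbers them 1 and 2 and keeps exactly one.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, row_number, col

# Hypothetical data with a tie: user_1 has two rows with score 3.
tied_df = sqlContext.createDataFrame(
    [("user_1", "object_1", 3), ("user_1", "object_2", 3), ("user_1", "object_3", 1)],
    ["user_id", "object_id", "score"])

w = Window.partitionBy("user_id").orderBy(col("score").desc())
tied_df.select("*", rank().over(w).alias("rank"),
               row_number().over(w).alias("row_number")).show()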
I know the question was asked for pyspark and I was looking for a similar answer in Scala, i.e.
Retrieve top n values in each group of a DataFrame in Scala
Here is the scala version of @mtoto's answer.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col
val window = Window.partitionBy("user_id").orderBy(col("score").desc)
val rankByScore = rank().over(window)
df1.select(col("*"), rankByScore.as("rank")).filter(col("rank") <= 2).show()
// you can change the value 2 to any number you want; here 2 represents the top 2 values
Here is another solution without a window function to get the top N records from a pySpark DataFrame.
# Import Libraries
from pyspark.sql.functions import col
# Sample Data
rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)
# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)
# Display DataFrame
sorted_df.show()
Output
+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2| 2|
| user_2| object_2| 2|
| user_1| object_1| 3|
| user_2| object_1| 5|
| user_2| object_2| 6|
+-------+---------+-----+
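Note that this approach orders the whole DataFrame, so head(5) returns the overall top rows rather than the top N per user_id. If you want a per-group result while still avoiding window functions, a minimal sketch (my own addition, assuming Spark 2.4+ for slice and reusing the df defined above) could look like:

from pyspark.sql import functions as f

# Collect (score, object_id) pairs per user, sort them in descending order,
# keep the first two, then explode back into one row per kept pair.
top2_per_user = (df.groupBy("user_id")
    .agg(f.slice(f.sort_array(f.collect_list(f.struct("score", "object_id")), asc=False), 1, 2).alias("top"))
    .select("user_id", f.explode("top").alias("t"))
    .select("user_id", f.col("t.object_id").alias("object_id"), f.col("t.score").alias("score")))
top2_per_user.show()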
If you are interested in more window functions in Spark you can refer to one of my blogs: https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86
with Python 3 and Spark 2.4
from pyspark.sql import Window
import pyspark.sql.functions as f
def get_topN(df, group_by_columns, order_by_column, n=1):
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

# order_by_column must be a Column expression (e.g. f.col("score")) so that .desc() works
top_n_df = get_topN(your_dataframe, [group_by_column], order_by_column, 1)
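For example, applied to the question's sample data (a hypothetical call, assuming df was built as in the earlier answers):

top_2_per_user = get_topN(df, ["user_id"], f.col("score"), n=2)
top_2_per_user.show()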
To find the Nth highest value in a PySpark SQL query using the ROW_NUMBER() function:
SELECT * FROM (
SELECT e.*,
ROW_NUMBER() OVER (ORDER BY col_name DESC) rn
FROM Employee e
)
WHERE rn = N
Here N is the rank of the required value, i.e. the Nth highest value in the column.
Output:
+-----------+
|col_name |
+-----------+
|1183395 |
+-----------+
The query will return the Nth highest value.
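To adapt this to the original per-group question, a hedged sketch (assuming the question's DataFrame is registered as a temporary view; the view name scores is illustrative) can partition the ROW_NUMBER() by user_id:

df.createOrReplaceTempView("scores")
spark.sql("""
SELECT user_id, object_id, score FROM (
    SELECT s.*,
           ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY score DESC) rn
    FROM scores s
) t
WHERE rn <= 2
""").show()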
Adding to mtoto's answer, you can also consider using row_number, because rank sometimes gives the same value to identical rows, and this might give you a different number of values per group. See below:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col
win = Window.partitionBy(sdf["user_id"]).orderBy(sdf["score"].desc())
sdf.select("*", row_number().over(win).alias('row_n'))\
.filter(col('row_n') <= 2)\
.show()
#+-------+---------+-----+-----+
#|user_id|object_id|score|row_n|
#+-------+---------+-----+-----+
#| user_1| object_1| 3| 1|
#| user_1| object_2| 2| 2|
#| user_2| object_2| 6| 1|
#| user_2| object_1| 5| 2|
#+-------+---------+-----+-----+