简体   繁体   English

PySpark DataFrame:查找最接近的值并切片该DataFrame

[英]PySpark DataFrame: Find closest value and slice the DataFrame

So, I've done enough research and haven't found a post that addresses what I want to do. 因此,我已经进行了足够的研究,但是没有找到可以解决我想要做的事情的帖子。

I have a PySpark DataFrame my_df which is sorted by value column- 我有一个PySpark DataFrame my_df ,它sorted valuesorted

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
|   F|   15|
|   G|   10|
+----+-----+

The summation of all the counts in value column is equal to 136 . value列中所有计数的总和等于136 I want to get all the rows whose combined values >= x% of 136 . 我想获取所有combined values >= x% of 136 In this example, let's say x=80 . 在此示例中,假设x=80 Then target sum = 0.8*136 = 108.8 . 然后target sum = 0.8*136 = 108.8 Hence, the new DataFrame will consist of all the rows that have a combined value >= 108.8 . 因此,新的DataFrame将由combined value >= 108.8的所有行组成。

In our example, this would come down to row D (since combined values upto D = 30+25+20+18 = 93 ). 在我们的示例中,这将下降到D行(因为组合值最高D = 30+25+20+18 = 93 )。

However, the hard part is that I also want to include the immediately following rows with duplicate values. 但是,困难的部分是我还想包括紧随其后的重复值行。 In this case, I also want to include row E since it has the same value as row D ie 18 . 在这种情况下,我还想包含E行,因为它具有与D行相同的值,即18

I want to slice my_df by giving a percentage x variable, for example 80 as discussed above. 我想通过给x百分比变量来切片my_df ,例如上面讨论的80 The new DataFrame should consist of the following rows- 新的DataFrame应该包含以下几行:

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
+----+-----+

One thing I could do here is iterate through the DataFrame (which is ~360k rows) , but I guess that defeats the purpose of Spark. 我在这里可以做的一件事是遍历DataFrame (which is ~360k rows) ,但我想这违反了Spark的目的。

Is there a concise function for what I want here? 我在这里想要的功能是否简洁?

Use pyspark SQL functions to do this concisely. 使用pyspark SQL函数可以做到这一点。

result = my_df.filter(my_df.value > target).select(my_df.name,my_df.value)
result.show()

Edit: Based on OP's question edit - Compute running sum and get rows until the target value is reached. 编辑:根据OP的问题进行编辑-计算运行总和并获取行,直到达到目标值为止。 Note that this will result in rows upto D, not E..which seems like a strange requirement. 请注意,这将导致直到D的行,而不是E ..,这似乎是一个奇怪的要求。

from pyspark.sql import Window
from pyspark.sql import functions as f

# Total sum of all `values`
target = (my_df.agg(sum("value")).collect())[0][0]

w = Window.orderBy(my_df.name) #Ideally this should be a column that specifies ordering among rows
running_sum_df = my_df.withColumn('rsum',f.sum(my_df.value).over(w))
running_sum_df.filter(running_sum_df.rsum <= 0.8*target)

Your requirements are quite strict, so it's difficult to formulate an efficient solution to your problem. 您的要求非常严格,因此很难为您的问题制定有效的解决方案。 Nevertheless, here is one approach: 不过,这是一种方法:

First calculate the cumulative sum and the total sum for the value column and filter the DataFrame using the percentage of target condition you specified. 首先计算value列的累计和和总和,然后使用您指定的目标条件的百分比过滤DataFrame。 Let's call this result df_filtered : 我们将此结果df_filtered

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.orderBy(f.col("value").desc(), "name").rangeBetween(Window.unboundedPreceding, 0)
target = 0.8

df_filtered = df.withColumn("cum_sum", f.sum("value").over(w))\
    .withColumn("total_sum", f.sum("value").over(Window.partitionBy()))\
    .where(f.col("cum_sum") <= f.col("total_sum")*target)

df_filtered.show()
#+----+-----+-------+---------+
#|name|value|cum_sum|total_sum|
#+----+-----+-------+---------+
#|   A|   30|     30|      136|
#|   B|   25|     55|      136|
#|   C|   20|     75|      136|
#|   D|   18|     93|      136|
#+----+-----+-------+---------+

Then join this filtered DataFrame back on the original on the value column. 然后,将此过滤后的DataFrame重新加入到value列的原始位置。 Since your DataFrame is already sorted by value , the final output will contain the rows you want. 由于您的DataFrame已经按value排序,所以最终输出将包含所需的行。

df.alias("r")\
    .join(
    df_filtered.alias('l'),
    on="value"
).select("r.name", "r.value").sort(f.col("value").desc(), "name").show()
#+----+-----+
#|name|value|
#+----+-----+
#|   A|   30|
#|   B|   25|
#|   C|   20|
#|   D|   18|
#|   E|   18|
#+----+-----+

The total_sum and cum_sum columns are calculated using a Window function . 使用Window函数计算 total_sumcum_sum列。

The Window w orders on the value column descending, followed by the name column. 窗口wvalue列上降序,然后在name列上降序。 The name column is used to break ties- without it, both rows C and D would have the same cumulative sum of 111 = 75+18+18 and you'd incorrectly lose both of them in the filter. name列用于打破平局-没有它, CD行的累积总和为111 = 75+18+18并且您会在过滤器中错误地丢失它们。

w = Window\                                     # Define Window
    .orderBy(                                   # This will define ordering
        f.col("value").desc(),                  # First sort by value descending
        "name"                                  # Sort on name second
    )\
    .rangeBetween(Window.unboundedPreceding, 0) # Extend back to beginning of window

The rangeBetween(Window.unboundedPreceding, 0) specifies that the Window should include all rows before the current row (defined by the orderBy ). rangeBetween(Window.unboundedPreceding, 0)指定Window应该包括当前行(由orderBy定义)之前的所有行。 This is what makes it a cumulative sum. 这就是它的累加总和。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM