
How to select the first n rows based on multiple conditions in PySpark

Now I have data like this:

+----+----+
|col1|   d|
+----+----+
|   A|   4|
|   A|  10|
|   A|   3|
|   B|   3|
|   B|   6|
|   B|   4|
|   B| 5.5|
|   B|  13|
+----+----+

col1 is StringType and d is actually TimestampType, but I use DoubleType here for simplicity. I want to select rows based on a list of condition tuples. Given the list [(A, 3.5), (A, 8), (B, 3.5), (B, 10)], I want a result like:

+----+---+
|col1|  d|
+----+---+
|   A|  4|
|   A| 10|
|   B|  4|
|   B| 13|
+----+---+

That is, for each tuple in the list, select from the PySpark DataFrame the first row (the one with the smallest d, as the expected output shows) whose d is larger than the tuple's number and whose col1 equals the tuple's string. What I have written so far is:

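# spark_empty_dataframe: an empty DataFrame with the same schema as df
df_res = spark_empty_dataframe
for x, y in tuples:
    # orderBy("d") makes limit(1) return the smallest matching d rather than an arbitrary row
    dft = df.filter(df.col1 == x).filter(df.d > y).orderBy("d").limit(1)
    df_res = df_res.union(dft)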

But I think this may be inefficient, and I am not sure whether I am right.

One way to avoid the loop is to create a DataFrame from the list of tuples you have as input:

t = [('A',3.5),('A',8),('B',3.5),('B',10)]
ref=spark.createDataFrame([(i[0],float(i[1])) for i in t],("col1_y","d_y"))

Then join the input DataFrame (df) to it on the condition, group by the tuple's key and value (which are repeated across the matching rows) to take the first value in each group, and drop the extra columns:

from pyspark.sql import functions as F

# note: F.first depends on row order, which Spark does not guarantee after the groupBy shuffle
(df.join(ref, (df.col1 == ref.col1_y) & (df.d > ref.d_y), how='inner')
   .orderBy("col1", "d")
   .groupBy("col1_y", "d_y")
   .agg(F.first("col1").alias("col1"), F.first("d").alias("d"))
   .drop("col1_y", "d_y")).show()

+----+----+
|col1|   d|
+----+----+
|   A|10.0|
|   A| 4.0|
|   B| 4.0|
|   B|13.0|
+----+----+

Note: if the order of the DataFrame matters, you can add an index column with monotonically_increasing_id, include it in the aggregation, and then orderBy that index column.

EDIT: another way is to take the minimum directly with min instead of ordering and taking the first row:

(df.join(ref, (df.col1 == ref.col1_y) & (df.d > ref.d_y), how='inner')
   .groupBy("col1_y", "d_y")
   .agg(F.min("col1").alias("col1"), F.min("d").alias("d"))
   .drop("col1_y", "d_y")).show()

+----+----+
|col1|   d|
+----+----+
|   B| 4.0|
|   B|13.0|
|   A| 4.0|
|   A|10.0|
+----+----+


 