
Reduce rows in a group based on a column value using Spark / Scala

I want to implement netting, in the sense of reducing the rows within each group based on the conditions below (a sketch of this logic follows the list):

- If the units column has both negative and positive values across the rows of a group, sum the units arithmetically; the resulting row takes its amt from the row with more units (in absolute value).
- If units has only positive or only negative values in a group, pass all the rows through as-is.
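To pin down the intent, here is a minimal pure-Scala sketch of the rule applied to a single group; the Row case class and netGroup function are hypothetical names used only for illustration (note that the Spark answer below uses max(amt) for the surviving amount, which coincides with this rule on the sample data):

case class Row(store: String, prod: String, amt: Double, units: Int)

// Net one (store, prod) group according to the rules above.
def netGroup(rows: Seq[Row]): Seq[Row] = {
  val hasPos = rows.exists(_.units > 0)
  val hasNeg = rows.exists(_.units < 0)
  if (hasPos && hasNeg) {
    val netUnits = rows.map(_.units).sum
    if (netUnits == 0) Seq.empty // e.g. -10 + 10 = 0: the whole group is dropped
    else {
      // take amt from the row with the larger absolute units
      val top = rows.maxBy(r => math.abs(r.units))
      Seq(top.copy(units = netUnits))
    }
  } else rows // only one sign present: pass the group through as-is
}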

Given the dataset below, I want to do netting, but I can't figure it out, since this is not a plain aggregation:

+-----+------+----+-----+
|store|prod  |amt |units|
+-----+------+----+-----+
|West |Apple |2.0 |-10  |
|West |Apple |3.0 |10   |
|West |Orange|5.0 |-15  |
|West |Orange|17.0|-15  |
|South|Orange|3.0 |9    |
|South|Orange|6.0 |-18  |
|East |Milk  |5.0 |-5   |
|West |Milk  |5.0 |8    |
+-----+------+----+-----+

Summing should happen only if there are at least 2 rows with opposite-sign units. For example, in the group below:

+-----+------+----+-----+
|West |Apple |2.0 |-10  |
|West |Apple |3.0 |10   |
+-----+------+----+-----+

There are 2 rows with units -10 and 10, so this group reduces to zero rows, since -10 + 10 = 0.

But in the group below, the units sum to -9:

+-----+------+----+-----+
|South|Orange|3.0 |9    |
|South|Orange|6.0 |-18  |
+-----+------+----+-----+

The result for this group is a single row that takes:

- its amt from the row with more units (in absolute value), and
- the summed units as its units.

+-----+------+----+-----+
|South|Orange|6.0 |-9   |
+-----+------+----+-----+

Any group whose rows do not contain both negative and positive units passes through unchanged.

So the final dataset should look like the one below:

+-----+------+----+-----+
|store|prod  |amt |units|
+-----+------+----+-----+
|West |Orange|5.0 |-15  |
|West |Orange|17.0|-15  |
|South|Orange|6.0 |-9   |
|East |Milk  |5.0 |-5   |
|West |Milk  |5.0 |8    |
+-----+------+----+-----+

The rows below are either a) removed:

+-----+------+----+-----+
|West |Apple |2.0 |-10  |
|West |Apple |3.0 |10   |
+-----+------+----+-----+

or

b) reduced

+-----+------+----+-----+
|South|Orange|3.0 |9    |
|South|Orange|6.0 |-18  |
+-----+------+----+-----+

to

+-----+------+----+-----+
|South|Orange|6.0 |-9   |
+-----+------+----+-----+

Just wrapping my comment up in an answer. I've checked the following code in the spark-shell and it worked.

// imports: the spark-shell brings in spark.implicits._ and
// org.apache.spark.sql.functions._ automatically; they are spelled out here
// so the snippet also compiles outside the shell
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, min, sum}
import spark.implicits._

// create the example dataframe
val data = Seq(
  ("West" , "Apple" ,  2.0, -10),
  ("West" , "Apple" ,  3.0,  10),
  ("West" , "Orange",  5.0, -15),
  ("West" , "Orange", 17.0, -15),
  ("South", "Orange",  3.0,   9),
  ("South", "Orange",  6.0, -18),
  ("East" , "Milk"  ,  5.0,  -5),
  ("West" , "Milk"  ,  5.0,   8))
val df_raw = spark.createDataFrame(data)
val col_names = Seq("store", "prod", "amt", "units")
val df = df_raw.toDF(col_names: _*)

// define a window over each (store, prod) group
val w = Window.partitionBy($"prod", $"store")

// flag each row's group for reduction via window functions: a group needs
// netting exactly when it contains both positive and negative units
val should_reduce_df = df.withColumn(
  "should_reduce",
  (max($"units").over(w) > 0) && (min($"units").over(w) < 0))

// rows in single-signed groups are passed on unchanged (drop the helper column)
val pass_df = should_reduce_df
  .filter(!$"should_reduce")
  .select(col_names.head, col_names.tail: _*)

// reduce the flagged groups: maximum amount and sum of units,
// then drop groups whose units net out to 0
// (=!= is Spark's Column inequality operator; !== is deprecated)
val reduced_df = should_reduce_df
  .filter($"should_reduce")
  .groupBy($"store", $"prod")
  .agg(max($"amt").alias("amt"), sum($"units").alias("units"))
  .filter($"units" =!= 0)

// union the passed-on and reduced rows
val final_df = pass_df.union(reduced_df)
final_df.show()
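For reference, final_df.show() on this data should print the expected result from the question; the exact row ordering of the union may differ from run to run:

+-----+------+----+-----+
|store|prod  |amt |units|
+-----+------+----+-----+
|West |Orange|5.0 |-15  |
|West |Orange|17.0|-15  |
|East |Milk  |5.0 |-5   |
|West |Milk  |5.0 |8    |
|South|Orange|6.0 |-9   |
+-----+------+----+-----+

The window-based flag is what makes this work where a plain groupBy/agg cannot: pass-through groups keep all of their original rows, and only the flagged groups are collapsed by aggregation.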
