
How to compute a cumulative sum under a limit with Spark?

After several tries and some research, I'm stuck trying to solve the following problem with Spark.

I have a DataFrame of elements, each with a priority and a quantity.

+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
|    f1| elmt 1|       1| 20|
|    f1| elmt 2|       2| 40|
|    f1| elmt 3|       3| 10|
|    f1| elmt 4|       4| 50|
|    f1| elmt 5|       5| 40|
|    f1| elmt 6|       6| 10|
|    f1| elmt 7|       7| 20|
|    f1| elmt 8|       8| 10|
+------+-------+--------+---+

I have a fixed limit quantity:

+------+--------+
|family|limitQty|
+------+--------+
|    f1|     100|
+------+--------+

I want to mark as "ok" the elements whose cumulative sum stays under the limit. Here is the expected result:

+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
|    f1| elmt 1|       1| 20|  1| -> 20 < 100   => ok
|    f1| elmt 2|       2| 40|  1| -> 20 + 40 < 100  => ok
|    f1| elmt 3|       3| 10|  1| -> 20 + 40 + 10 < 100   => ok
|    f1| elmt 4|       4| 50|  0| -> 20 + 40 + 10 + 50 > 100   => ko 
|    f1| elmt 5|       5| 40|  0| -> 20 + 40 + 10 + 40 > 100   => ko  
|    f1| elmt 6|       6| 10|  1| -> 20 + 40 + 10 + 10 < 100   => ok
|    f1| elmt 7|       7| 20|  1| -> 20 + 40 + 10 + 10 + 20 <= 100   => ok
|    f1| elmt 8|       8| 10|  0| -> 20 + 40 + 10 + 10 + 20 + 10 > 100   => ko
+------+-------+--------+---+---+  

I tried to solve it with a cumulative sum:

    initDF
      .join(limitQtyDF, Seq("family"), "left_outer")
      .withColumn("cumulSum", sum($"qty").over(Window.partitionBy("family").orderBy("priority")))
      .withColumn("ok", when($"cumulSum" <= $"limitQty", 1).otherwise(0))
      .drop("cumulSum", "limitQty")

But it's not enough, because the elements after the first one that exceeds the limit are not taken into account. I can't find a way to solve it with Spark. Do you have an idea?
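
To make the issue concrete, here is what that attempt produces on the sample data (a hand trace keeping the intermediate cumulSum column, not actual program output): once one element exceeds the limit, every later element is rejected too, so elmt 6 and elmt 7 are wrongly marked 0.

+------+-------+--------+---+--------+---+
|family|element|priority|qty|cumulSum| ok|
+------+-------+--------+---+--------+---+
|    f1| elmt 1|       1| 20|      20|  1|
|    f1| elmt 2|       2| 40|      60|  1|
|    f1| elmt 3|       3| 10|      70|  1|
|    f1| elmt 4|       4| 50|     120|  0|
|    f1| elmt 5|       5| 40|     160|  0|
|    f1| elmt 6|       6| 10|     170|  0| -> should be 1
|    f1| elmt 7|       7| 20|     190|  0| -> should be 1
|    f1| elmt 8|       8| 10|     200|  0|
+------+-------+--------+---+--------+---+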

Here is the corresponding Scala code:

    val sparkSession = SparkSession.builder()
      .master("local[*]")
      .getOrCreate()

    import sparkSession.implicits._

    val initDF = Seq(
      ("f1", "elmt 1", 1, 20),
      ("f1", "elmt 2", 2, 40),
      ("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

    val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

    val expectedDF = Seq(
      ("f1", "elmt 1", 1, 20, 1),
      ("f1", "elmt 2", 2, 40, 1),
      ("f1", "elmt 3", 3, 10, 1),
      ("f1", "elmt 4", 4, 50, 0),
      ("f1", "elmt 5", 5, 40, 0),
      ("f1", "elmt 6", 6, 10, 1),
      ("f1", "elmt 7", 7, 20, 1),
      ("f1", "elmt 8", 8, 10, 0)
    ).toDF("family", "element", "priority", "qty", "ok").show()

Thank you for your help!

The solutions are shown below:

scala> initDF.show
+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
|    f1| elmt 1|       1| 20|
|    f1| elmt 2|       2| 40|
|    f1| elmt 3|       3| 10|
|    f1| elmt 4|       4| 50|
|    f1| elmt 5|       5| 40|
|    f1| elmt 6|       6| 10|
|    f1| elmt 7|       7| 20|
|    f1| elmt 8|       8| 10|
+------+-------+--------+---+

scala> val df1 = initDF.groupBy("family").agg(collect_list("qty").as("comb_qty"), collect_list("priority").as("comb_prior"), collect_list("element").as("comb_elem"))
df1: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 2 more fields]

scala> df1.show
+------+--------------------+--------------------+--------------------+
|family|            comb_qty|          comb_prior|           comb_elem|
+------+--------------------+--------------------+--------------------+
|    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|
+------+--------------------+--------------------+--------------------+


scala> val df2 = df1.join(limitQtyDF, df1("family") === limitQtyDF("family")).drop(limitQtyDF("family"))
df2: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]

scala> df2.show
+------+--------------------+--------------------+--------------------+--------+
|family|            comb_qty|          comb_prior|           comb_elem|limitQty|
+------+--------------------+--------------------+--------------------+--------+
|    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|     100|
+------+--------------------+--------------------+--------------------+--------+


scala> def validCheck = (qty: Seq[Int], limit: Int) => {
     | var sum = 0
     | qty.map(elem => {
     | if (elem + sum <= limit) {
     | sum = sum + elem
     | 1}else{
     | 0
     | }})}
validCheck: (scala.collection.mutable.Seq[Int], Int) => scala.collection.mutable.Seq[Int]

scala> val newUdf = udf(validCheck)
newUdf: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,ArrayType(IntegerType,false),Some(List(ArrayType(IntegerType,false), IntegerType)))

scala> val df3 = df2.withColumn("valid", newUdf(col("comb_qty"),col("limitQty"))).drop("limitQty")
df3: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]

scala> df3.show
+------+--------------------+--------------------+--------------------+--------------------+
|family|            comb_qty|          comb_prior|           comb_elem|               valid|
+------+--------------------+--------------------+--------------------+--------------------+
|    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|[1, 1, 1, 0, 0, 1...|
+------+--------------------+--------------------+--------------------+--------------------+

scala> val myUdf = udf((qty: Seq[Int], prior: Seq[Int], elem: Seq[String], valid: Seq[Int]) => {
     | elem zip prior zip qty zip valid map{
     | case (((a,b),c),d) => (a,b,c,d)}
     | }
     | )

scala> val df4 = df3.withColumn("combined", myUdf(col("comb_qty"),col("comb_prior"),col("comb_elem"),col("valid")))
df4: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 4 more fields]



scala> val df5 = df4.drop("comb_qty","comb_prior","comb_elem","valid")
df5: org.apache.spark.sql.DataFrame = [family: string, combined: array<struct<_1:string,_2:int,_3:int,_4:int>>]

scala> df5.show(false)
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
|family|combined                                                                                                                                                        |
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
|f1    |[[elmt 1, 1, 20, 1], [elmt 2, 2, 40, 1], [elmt 3, 3, 10, 1], [elmt 4, 4, 50, 0], [elmt 5, 5, 40, 0], [elmt 6, 6, 10, 1], [elmt 7, 7, 20, 1], [elmt 8, 8, 10, 0]]|
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+

scala> val df6 = df5.withColumn("combined",explode(col("combined")))
df6: org.apache.spark.sql.DataFrame = [family: string, combined: struct<_1: string, _2: int ... 2 more fields>]

scala> df6.show
+------+------------------+
|family|          combined|
+------+------------------+
|    f1|[elmt 1, 1, 20, 1]|
|    f1|[elmt 2, 2, 40, 1]|
|    f1|[elmt 3, 3, 10, 1]|
|    f1|[elmt 4, 4, 50, 0]|
|    f1|[elmt 5, 5, 40, 0]|
|    f1|[elmt 6, 6, 10, 1]|
|    f1|[elmt 7, 7, 20, 1]|
|    f1|[elmt 8, 8, 10, 0]|
+------+------------------+

scala> val df7 = df6.select("family", "combined._1", "combined._2", "combined._3", "combined._4").withColumnRenamed("_1","element").withColumnRenamed("_2","priority").withColumnRenamed("_3", "qty").withColumnRenamed("_4","ok")
df7: org.apache.spark.sql.DataFrame = [family: string, element: string ... 3 more fields]

scala> df7.show
+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
|    f1| elmt 1|       1| 20|  1|
|    f1| elmt 2|       2| 40|  1|
|    f1| elmt 3|       3| 10|  1|
|    f1| elmt 4|       4| 50|  0|
|    f1| elmt 5|       5| 40|  0|
|    f1| elmt 6|       6| 10|  1|
|    f1| elmt 7|       7| 20|  1|
|    f1| elmt 8|       8| 10|  0|
+------+-------+--------+---+---+

Let me know if it helps!!
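
As a side note, the same collect_list-plus-UDF idea can be condensed by collecting a single array of structs and letting sort_array fix the ordering. This is a sketch of mine, not part of the answer above; it assumes the initDF, limitQtyDF and implicits from the question, plus Spark 2.4+ for arrays_zip:

import org.apache.spark.sql.functions._

// Greedy check: an element is kept only if it still fits under the limit.
val markFits = udf { (qty: Seq[Int], limit: Int) =>
  var running = 0
  qty.map { q => if (running + q <= limit) { running += q; 1 } else 0 }
}

val result = initDF
  .groupBy("family")
  .agg(collect_list(struct($"priority", $"element", $"qty")).as("rows"))
  .join(limitQtyDF, Seq("family"))
  .withColumn("rows", sort_array($"rows"))                   // sort by priority (first struct field)
  .withColumn("ok", markFits($"rows.qty", $"limitQty"))
  .withColumn("zipped", explode(arrays_zip($"rows", $"ok"))) // Spark 2.4+
  .select(
    $"family",
    $"zipped.rows.element".as("element"),
    $"zipped.rows.priority".as("priority"),
    $"zipped.rows.qty".as("qty"),
    $"zipped.ok".as("ok"))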

Another way to do it would be a row-by-row approach, iterating over the collected DataFrame on the driver.

import org.apache.spark.sql.Row

// Rows rebuilt with the extra flag column, plus a running total.
var bufferRow: collection.mutable.Buffer[Row] = collection.mutable.Buffer.empty[Row]
var tempSum: Double = 0
val iterator = df.collect.iterator
while(iterator.hasNext){
  val record = iterator.next()
  val y = record.getAs[Integer]("qty")
  tempSum = tempSum + y
  if (tempSum <= 100.0) {
    bufferRow = bufferRow ++ Seq(transformRow(record, 1))
  }
  else {
    // The element does not fit: flag it with 0 and undo the addition.
    bufferRow = bufferRow ++ Seq(transformRow(record, 0))
    tempSum = tempSum - y
  }
}

Define the transformRow function, which appends a column to a row:

def transformRow(row: Row,flag : Int): Row =  Row.fromSeq(row.toSeq ++ Array[Integer](flag))

The next thing to do is to add the additional column to the schema:

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val newSchema = StructType(df.schema.fields ++ Array(StructField("C_Sum", IntegerType, false)))

Finally, create a new DataFrame:

val outputdf = spark.createDataFrame(spark.sparkContext.parallelize(bufferRow.toSeq),newSchema)

Output DataFrame:

+------+-------+--------+---+-----+
|family|element|priority|qty|C_Sum|
+------+-------+--------+---+-----+
|    f1|  elmt1|       1| 20|    1|
|    f1|  elmt2|       2| 40|    1|
|    f1|  elmt3|       3| 10|    1|
|    f1|  elmt4|       4| 50|    0|
|    f1|  elmt5|       5| 40|    0|
|    f1|  elmt6|       6| 10|    1|
|    f1|  elmt7|       7| 20|    1|
|    f1|  elmt8|       8| 10|    0|
+------+-------+--------+---+-----+

I am new to Spark, so this solution may not be optimal. I am assuming the value of 100 is an input to the program. In that case:

case class Frame(family: String, element: String, priority: Int, qty: Int)

import scala.collection.JavaConverters._

// Fold over a local iterator, keeping the priorities of the rows that still fit under 100.
val ans = df.as[Frame].toLocalIterator
  .asScala
  .foldLeft((Seq.empty[Int], 0))((acc, a) =>
    if (acc._2 + a.qty <= 100) (acc._1 :+ a.priority, acc._2 + a.qty) else acc)._1

// Flag the rows whose priority participates in the sum.
df.withColumn("OK", when($"priority".isin(ans: _*), 1).otherwise(0)).show

results in:

+------+-------+--------+---+--------+
|family|element|priority|qty|OK      |
+------+-------+--------+---+--------+
|    f1| elmt 1|       1| 20|       1|
|    f1| elmt 2|       2| 40|       1|
|    f1| elmt 3|       3| 10|       1|
|    f1| elmt 4|       4| 50|       0|
|    f1| elmt 5|       5| 40|       0|
|    f1| elmt 6|       6| 10|       1|
|    f1| elmt 7|       7| 20|       1|
|    f1| elmt 8|       8| 10|       0|
+------+-------+--------+---+--------+

The idea is simply to get a local Scala iterator, extract the priorities of the participating rows from it, and then use those values to flag the matching rows. Since this solution gathers all the data in memory on one machine, it could run into memory problems if the DataFrame is too large to fit in memory.
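
Building on that fold-based answer, here is a sketch of my own (not part of the answer above) that reads the limit per family from limitQtyDF instead of hardcoding 100; it still iterates on the driver, so the same memory caveat applies. It assumes the initDF, limitQtyDF and implicits from the question:

import scala.collection.JavaConverters._
import org.apache.spark.sql.functions._

case class Frame(family: String, element: String, priority: Int, qty: Int)

// Driver-side map of family -> limit (assumes limitQtyDF is small).
val limits: Map[String, Int] =
  limitQtyDF.collect().map(r => r.getString(0) -> r.getInt(1)).toMap

// Fold over the rows in (family, priority) order, keeping a running sum per
// family and collecting the (family, priority) pairs that still fit.
val kept = initDF.as[Frame]
  .orderBy("family", "priority")
  .toLocalIterator.asScala
  .foldLeft((List.empty[(String, Int)], Map.empty[String, Int])) {
    case ((ok, sums), f) =>
      val newSum = sums.getOrElse(f.family, 0) + f.qty
      if (newSum <= limits.getOrElse(f.family, 0))
        ((f.family, f.priority) :: ok, sums.updated(f.family, newSum))
      else (ok, sums)
  }._1

// Flag the kept rows with a left join.
val keptDF = kept.toDF("family", "priority").withColumn("ok", lit(1))
initDF.join(keptDF, Seq("family", "priority"), "left")
  .na.fill(0, Seq("ok"))
  .orderBy("family", "priority")
  .show()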

Cumulative sum for each group (in PySpark):

from pyspark.sql import functions as f
from pyspark.sql.window import Window as window
from pyspark.sql.types import IntegerType,StringType,FloatType,StructType,StructField,DateType
schema = StructType() \
        .add(StructField("empno",IntegerType(),True)) \
        .add(StructField("ename",StringType(),True)) \
        .add(StructField("job",StringType(),True)) \
        .add(StructField("mgr",StringType(),True)) \
        .add(StructField("hiredate",DateType(),True)) \
        .add(StructField("sal",FloatType(),True)) \
        .add(StructField("comm",StringType(),True)) \
        .add(StructField("deptno",IntegerType(),True))

emp = spark.read.csv('data/emp.csv',schema)
dept_partition = window.partitionBy(emp.deptno).orderBy(emp.sal)
emp_win = emp.withColumn("dept_cum_sal", 
                         f.sum(emp.sal).over(dept_partition.rowsBetween(window.unboundedPreceding, window.currentRow)))
emp_win.show()

The results look like this:

+-----+------+---------+----+----------+------+-------+------+------------+
|empno| ename|      job| mgr|  hiredate|   sal|   comm|deptno|dept_cum_sal|
+-----+------+---------+----+----------+------+-------+------+------------+
| 7369| SMITH|    CLERK|7902|1980-12-17| 800.0|   null|    20|       800.0|
| 7876| ADAMS|    CLERK|7788|1983-01-12|1100.0|   null|    20|      1900.0|
| 7566| JONES|  MANAGER|7839|1981-04-02|2975.0|   null|    20|      4875.0|
| 7788| SCOTT|  ANALYST|7566|1982-12-09|3000.0|   null|    20|      7875.0|
| 7902|  FORD|  ANALYST|7566|1981-12-03|3000.0|   null|    20|     10875.0|
| 7934|MILLER|    CLERK|7782|1982-01-23|1300.0|   null|    10|      1300.0|
| 7782| CLARK|  MANAGER|7839|1981-06-09|2450.0|   null|    10|      3750.0|
| 7839|  KING|PRESIDENT|null|1981-11-17|5000.0|   null|    10|      8750.0|
| 7900| JAMES|    CLERK|7698|1981-12-03| 950.0|   null|    30|       950.0|
| 7521|  WARD| SALESMAN|7698|1981-02-22|1250.0| 500.00|    30|      2200.0|
| 7654|MARTIN| SALESMAN|7698|1981-09-28|1250.0|1400.00|    30|      3450.0|
| 7844|TURNER| SALESMAN|7698|1981-09-08|1500.0|   0.00|    30|      4950.0|
| 7499| ALLEN| SALESMAN|7698|1981-02-20|1600.0| 300.00|    30|      6550.0|
| 7698| BLAKE|  MANAGER|7839|1981-05-01|2850.0|   null|    30|      9400.0|
+-----+------+---------+----+----------+------+-------+------+------------+

Please find the answer below.

val initDF = Seq(("f1", "elmt 1", 1, 20),("f1", "elmt 2", 2, 40),("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

import org.apache.spark.sql.Row

sc.broadcast(limitQtyDF)

val joinedInitDF = initDF.join(limitQtyDF, Seq("family"), "left")

case class dataResult(family: String, element: String, priority: Int, qty: Int, comutedValue: Int, limitQty: Int, controlOut: String)

val familyIDs = initDF.select("family").distinct.collect.map(_(0).toString).toList

// For one family: walk the rows in order, keep a running total and flag each
// row "ok" if it still fits under the limit, "ko" otherwise.
def checkingUDF(inputRows: List[Row]) = {
  var controlVarQty = 0
  val outputArrayBuffer = collection.mutable.ArrayBuffer[dataResult]()
  val setLimit = inputRows.head.getInt(4)
  for (inputRow <- inputRows) {
    val currQty = inputRow.getInt(3)
    controlVarQty + currQty match {
      case value if value <= setLimit =>
        controlVarQty += currQty
        outputArrayBuffer += dataResult(inputRow.getString(0), inputRow.getString(1), inputRow.getInt(2), inputRow.getInt(3), value, setLimit, "ok")
      case value =>
        outputArrayBuffer += dataResult(inputRow.getString(0), inputRow.getString(1), inputRow.getInt(2), inputRow.getInt(3), value, setLimit, "ko")
    }
  }
  outputArrayBuffer.toList
}

val tmpAB = collection.mutable.ArrayBuffer[List[dataResult]]()
for (familyID <- familyIDs) {
  // Process one family at a time on the driver.
  val currentFamily = joinedInitDF.filter(s"family = '${familyID}'").orderBy("element", "priority").collect.toList
  tmpAB += checkingUDF(currentFamily)
}

tmpAB.toSeq.flatMap(x => x).toDF.show(false)

This works for me.

+------+-------+--------+---+------------+--------+----------+
|family|element|priority|qty|comutedValue|limitQty|controlOut|
+------+-------+--------+---+------------+--------+----------+
|f1    |elmt 1 |1       |20 |20          |100     |ok        |
|f1    |elmt 2 |2       |40 |60          |100     |ok        |
|f1    |elmt 3 |3       |10 |70          |100     |ok        |
|f1    |elmt 4 |4       |50 |120         |100     |ko        |
|f1    |elmt 5 |5       |40 |110         |100     |ko        |
|f1    |elmt 6 |6       |10 |80          |100     |ok        |
|f1    |elmt 7 |7       |20 |100         |100     |ok        |
|f1    |elmt 8 |8       |10 |110         |100     |ko        |
+------+-------+--------+---+------------+--------+----------+

Please drop the unnecessary columns from the output as needed (see the sketch below).
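
A quick sketch of mine, reusing tmpAB from above and keeping only the columns of the expected result:

tmpAB.toSeq.flatMap(x => x).toDF
  .drop("comutedValue", "limitQty")
  .show(false)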
