
Sum previous row value in current row in Spark Scala

I am trying to adjust one of the column values based on a value in another data frame. While doing this, if there is any leftover amount, I need to carry it forward to the next row and calculate the final amount.

During this operation, I am not able to carry the previous row's leftover amount over into the next row's calculation. I tried the lag window function and running totals, but those are not working as expected.

I am working with Scala. Here is the input data:

val consumption = sc.parallelize(Seq((20180101, 600), (20180201, 900),(20180301, 400),(20180401, 600),(20180501, 1000),(20180601, 1900),(20180701, 500),(20180801, 100),(20180901, 500))).toDF("Month","Usage")
consumption.show()
+--------+-----+
|   Month|Usage|
+--------+-----+
|20180101|  600|
|20180201|  900|
|20180301|  400|
|20180401|  600|
|20180501| 1000|
|20180601| 1900|
|20180701|  500|
|20180801|  100|
|20180901|  500|
+--------+-----+
val promo = sc.parallelize(Seq((20180101, 1000),(20180201, 100),(20180401, 3000))).toDF("PromoEffectiveMonth","promoAmount")
promo.show()
+-------------------+-----------+
|PromoEffectiveMonth|promoAmount|
+-------------------+-----------+
|           20180101|       1000|
|           20180201|        100|
|           20180401|       3000|
+-------------------+-----------+

Expected result:

val finaldf = sc.parallelize(Seq((20180101,600,400,600),(20180201,900,0,400),(20180301,400,0,0),(20180401,600,2400,600),(20180501,1000,1400,1000),(20180601,1900,0,500),(20180701,500,0,0),(20180801,100,0,0),(20180901,500,0,0))).toDF("Month","Usage","LeftOverPromoAmt","AdjustedUsage")
finaldf.show()
+--------+-----+----------------+-------------+
|   Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101|  600|             400|          600|
|20180201|  900|               0|          400|
|20180301|  400|               0|            0|
|20180401|  600|            2400|          600|
|20180501| 1000|            1400|         1000|
|20180601| 1900|               0|          500|
|20180701|  500|               0|            0|
|20180801|  100|               0|            0|
|20180901|  500|               0|            0|
+--------+-----+----------------+-------------+

The logic I am applying: based on the join between Month and PromoEffectiveMonth, the promo amount needs to be applied against the consumption Usage column until the promo amount becomes zero.

E.g.: in Jan '18 the promo amount is 1000; after deducting the usage (600), the leftover promo amount is 400 and the adjusted usage is 600. The leftover 400 is carried into the next month, where it is added to February's promo amount (100), so the total promo available is 500. Here the usage (900) is greater than the available promo.

So the leftover promo amount is zero and the adjusted usage is 400 (900 - 500).
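
To make the intended arithmetic concrete, here is a minimal plain-Scala sketch of the carry-forward rule described above (the sample values are the first four months from the input; the variable names are illustrative only, not part of the Spark solution below):

// Plain-Scala sketch of the carry-forward rule (illustrative only)
val sample = Seq((20180101, 600L, 1000L), (20180201, 900L, 100L),
                 (20180301, 400L, 0L), (20180401, 600L, 3000L))

val (_, traced) = sample.foldLeft((0L, Vector.empty[(Int, Long, Long)])) {
  case ((carry, acc), (month, usage, promo)) =>
    val available = carry + promo                // promo available this month
    if (usage < available)                       // promo covers the whole usage
      (available - usage, acc :+ ((month, available - usage, usage)))
    else {                                       // promo runs out (or was zero)
      val adjusted = if (available == 0) 0L else usage - available
      (0L, acc :+ ((month, 0L, adjusted)))
    }
}
// traced: Vector((Month, LeftOverPromoAmt, AdjustedUsage))
// -> (20180101,400,600), (20180201,0,400), (20180301,0,0), (20180401,2400,600)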

First of all, you need to perform a left_outer join so that each row has its corresponding promotion. The join is performed on the fields Month and PromoEffectiveMonth from the consumption and promo datasets, respectively. Note also that I have created a new column, Timestamp, by using the Spark SQL unix_timestamp function. It will be used to sort the dataset by date.

import org.apache.spark.sql.functions.unix_timestamp
import org.apache.spark.sql.types.TimestampType

val ds = consumption
    .join(promo, consumption.col("Month") === promo.col("PromoEffectiveMonth"), "left_outer")
    .select("Month", "Usage", "promoAmount")
    .withColumn("Timestamp", unix_timestamp($"Month".cast("string"), "yyyyMMdd").cast(TimestampType))

This is the result of these operations.

+--------+-----+-----------+-------------------+
|   Month|Usage|promoAmount|          Timestamp|
+--------+-----+-----------+-------------------+
|20180301|  400|       null|2018-03-01 00:00:00|
|20180701|  500|       null|2018-07-01 00:00:00|
|20180901|  500|       null|2018-09-01 00:00:00|
|20180101|  600|       1000|2018-01-01 00:00:00|
|20180801|  100|       null|2018-08-01 00:00:00|
|20180501| 1000|       null|2018-05-01 00:00:00|
|20180201|  900|        100|2018-02-01 00:00:00|
|20180601| 1900|       null|2018-06-01 00:00:00|
|20180401|  600|       3000|2018-04-01 00:00:00|
+--------+-----+-----------+-------------------+

Next, you have to create a Window. Window functions are used to perform calculations over a group of records by using some criteria (more info on this here). In our case, the criterion is to sort the rows by Timestamp.

import org.apache.spark.sql.expressions.Window

val window = Window.orderBy("Timestamp")
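
One subtlety worth knowing: because the window is ordered, its frame defaults to all rows from the start of the partition up to the current row, which is exactly what lets the aggregate below carry state forward row by row. Since the timestamps here are unique, you could optionally spell that out with an equivalent explicit ROWS frame:

val window = Window
  .orderBy("Timestamp")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

Also note that a window without partitionBy moves all rows into a single partition; that is fine for a dataset this small, but it will not scale to large data.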

Okay, now comes the hardest part. You need to create a User Defined Aggregate Function (UDAF). In this function, each group will be processed according to a custom operation, and it will enable you to process each row while taking into account the value of the previous one.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

  class CalculatePromos extends UserDefinedAggregateFunction {
    // Input schema for this UserDefinedAggregateFunction
    override def inputSchema: StructType =
      StructType(
        StructField("Usage", LongType) ::
        StructField("promoAmount", LongType) :: Nil)

    // Schema for the parameters that will be used internally to buffer temporary values
    override def bufferSchema: StructType = StructType(
        StructField("AdjustedUsage", LongType) ::
        StructField("LeftOverPromoAmt", LongType) :: Nil
    )

    // The data type returned by this UserDefinedAggregateFunction.
    // In this case, it will return a StructType with two fields: AdjustedUsage and LeftOverPromoAmt
    override def dataType: DataType = StructType(Seq(StructField("AdjustedUsage", LongType), StructField("LeftOverPromoAmt", LongType)))

    // Whether this UDAF is deterministic or not. In this case, it is
    override def deterministic: Boolean = true

    // Initial values for the temporary values declared above
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }

    // In this function, the values associated to the buffer schema are updated
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

      val promoAmount = if(input.isNullAt(1)) 0L else input.getLong(1)
      val leftOverAmount = buffer.getLong(1)
      val usage = input.getLong(0)
      val currentPromo = leftOverAmount + promoAmount

      if(usage < currentPromo) {
        buffer(0) = usage
        buffer(1) = currentPromo - usage
      } else {
        if(currentPromo == 0)
          buffer(0) = 0L
        else
          buffer(0) = usage - currentPromo
        buffer(1) = 0L
      }
    }

    // Function used to merge two objects. In this case, it is not necessary to define this method since
    // the whole logic has been implemented in update
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}

    // It is what you will return. In this case, a tuple of the buffered values, which represent AdjustedUsage and LeftOverPromoAmt
    override def evaluate(buffer: Row): Any = {
      (buffer.getLong(0), buffer.getLong(1))
    }

  }

Basically, it creates a function that can be used in Spark SQL. It receives two columns (Usage and promoAmount, as specified in the method inputSchema) and returns a new column with two subfields (AdjustedUsage and LeftOverPromoAmt, as defined in the method dataType). The method bufferSchema lets you declare temporary values used to support the operation; in this case, AdjustedUsage and LeftOverPromoAmt.

The logic you are applying is implemented in the method update. Basically, it takes the values previously calculated and updates them. The argument buffer contains the temporary values defined in bufferSchema, and input holds the values of the row being processed at that moment. Finally, evaluate returns a tuple containing the result of the operations for the current row, i.e. the temporary values defined in bufferSchema and updated in the method update.
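
A quick hand trace of the buffer (AdjustedUsage, LeftOverPromoAmt) over the first three months may help; this is worked out from the logic above, not program output:

Row (Usage=600, promoAmount=1000): available = 0 + 1000 = 1000 -> buffer = (600, 400)
Row (Usage=900, promoAmount=100):  available = 400 + 100 = 500 -> buffer = (400, 0)
Row (Usage=400, promoAmount=null): available = 0 + 0 = 0       -> buffer = (0, 0)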

The next step is to create a variable by instantiating the class CalculatePromos.

val calculatePromos = new CalculatePromos
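
Optionally, if you also want to call the function from Spark SQL by name, you could register the instance (this assumes a SparkSession named spark; it is not required for the withColumn usage below):

spark.udf.register("calculatePromos", calculatePromos)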

Finally, you have to apply the User Defined Aggregate Function calculatePromos by using the method withColumn of the dataset. Note that you have to pass it the input columns (Usage and promoAmount) and apply the window by using the method over.

ds
  .withColumn("output", calculatePromos($"Usage", $"promoAmount").over(window))
  .select($"Month", $"Usage", $"output.LeftOverPromoAmt".as("LeftOverPromoAmt"), $"output.AdjustedUsage".as("AdjustedUsage"))

This is the result:

+--------+-----+----------------+-------------+
|   Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101|  600|             400|          600|
|20180201|  900|               0|          400|
|20180301|  400|               0|            0|
|20180401|  600|            2400|          600|
|20180501| 1000|            1400|         1000|
|20180601| 1900|               0|          500|
|20180701|  500|               0|            0|
|20180801|  100|               0|            0|
|20180901|  500|               0|            0|
+--------+-----+----------------+-------------+

Hope it helps.
