
Calculating column value in current row of Spark Dataframe based on the calculated value of a different column in previous row using Scala

Suppose I have a DataFrame like the one below:

Id A B C D
1 100 10 20 5
2 0 5 10 5
3 0 7 2 3
4 0 1 3 7

It needs to be converted into the following:

Id A B C D E
1 100 10 20 5 75
2 75 5 10 5 60
3 60 7 2 3 50
4 50 1 3 7 40

The calculation works as follows:

  1. The DataFrame gets a new column E, which for row 1 is calculated as col(A) - (max(col(B), col(C)) + col(D)) => 100 - (max(10, 20) + 5) = 75
  2. In the row with Id 2, the value of column E from row 1 is brought forward as the value of column A
  3. So, for row 2, column E is 75 - (max(5, 10) + 5) = 60
  4. Similarly, in the row with Id 3, the value of A becomes 60 and the new value of column E is calculated from it
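To make the rule concrete, the steps above can be sketched in plain Scala without Spark. This is only an illustration of the recurrence on the sample rows; the names RecurrenceSketch, Row, Out and carryForward are hypothetical and not part of any Spark API.

```scala
object RecurrenceSketch {
  // Hypothetical row types used only for this illustration.
  final case class Row(id: Int, a: Int, b: Int, c: Int, d: Int)
  final case class Out(id: Int, a: Int, b: Int, c: Int, d: Int, e: Int)

  // The sample rows from the first table.
  val rows: Vector[Row] = Vector(
    Row(1, 100, 10, 20, 5),
    Row(2, 0, 5, 10, 5),
    Row(3, 0, 7, 2, 3),
    Row(4, 0, 1, 3, 7)
  )

  // Walk the rows in Id order, carrying the previous E forward as the next A.
  def carryForward(input: Seq[Row]): Vector[Out] =
    input.foldLeft(Vector.empty[Out]) { (acc, r) =>
      // Row 1 keeps its own A; every later row takes the previous row's E.
      val a = acc.lastOption.map(_.e).getOrElse(r.a)
      val e = a - (math.max(r.b, r.c) + r.d)
      acc :+ Out(r.id, a, r.b, r.c, r.d, e)
    }

  val result: Vector[Out] = carryForward(rows)
  // result.map(_.a) is Vector(100, 75, 60, 50)
  // result.map(_.e) is Vector(75, 60, 50, 40)
}
```

Each step needs the E produced by the step before it, which is what makes this a sequential fold rather than a per-row expression.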

The problem is that, except for the first row, the value of column A depends on the values calculated for the previous row.

Is it possible to solve this using window functions and lag?

You can use the collect_list function over a Window ordered by the Id column to get a cumulative array of structs that hold the values of A and max(B, C) + D (as field T). Then apply aggregate to that array to calculate column E.

Note that in this particular case you can't use the lag window function, because you need the calculated values recursively.
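As a rough illustration of why a single lag is not enough, the plain Scala sketch below (no Spark) mimics what a one-step lag-based expression could see: the previous row's E, but computed from that row's original A. The names NaiveLagSketch, Row, e and naiveA are hypothetical and only serve this example.

```scala
object NaiveLagSketch {
  // Hypothetical row type used only for this illustration.
  final case class Row(id: Int, a: Int, b: Int, c: Int, d: Int)

  // The sample rows from the question.
  val rows: Vector[Row] = Vector(
    Row(1, 100, 10, 20, 5),
    Row(2, 0, 5, 10, 5),
    Row(3, 0, 7, 2, 3),
    Row(4, 0, 1, 3, 7)
  )

  // E computed from a row's ORIGINAL A, which is all a one-step lag can reach.
  def e(r: Row): Int = r.a - (math.max(r.b, r.c) + r.d)

  // Row 1 keeps its A; each later row takes the previous row's (naive) E.
  val naiveA: Vector[Int] = rows.head.a +: rows.init.map(e)
  // naiveA is Vector(100, 75, -15, -10), while the expected A is Vector(100, 75, 60, 50)
}
```

Row 2 comes out right only because row 1 has a real A; from row 3 onward the original A is 0, so the carried value is wrong.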

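import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df2 = df.withColumn(
  "tmp",
  // Cumulative array of (A, T) structs up to the current row, with T = max(B, C) + D
  collect_list(
    struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
  ).over(Window.orderBy("Id"))
).withColumn(
  "E",
  // E = A of the first row minus the sum of all T values seen so far
  expr("aggregate(transform(tmp, (x, i) -> IF(i=0, x.A - x.T, -x.T)), 0, (acc, x) -> acc + x)")
).withColumn(
  "A",
  // Recover A for the current row from E: A = E + T
  col("E") + greatest(col("B"), col("C")) + col("D")
).drop("tmp")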

df2.show(false)

//+---+---+---+---+---+---+
//|Id |A  |B  |C  |D  |E  |
//+---+---+---+---+---+---+
//|1  |100|10 |20 |5  |75 |
//|2  |75 |5  |10 |5  |60 |
//|3  |60 |7  |2  |3  |50 |
//|4  |50 |1  |3  |7  |40 |
//+---+---+---+---+---+---+

You can display the intermediate column tmp to understand the logic behind the calculation.
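The idea behind tmp can also be sketched in plain Scala, without Spark. Unrolling the recurrence gives E for row n as the first row's A minus the sum of every T up to row n, which is what the transform and aggregate expression computes. The names CumulativeSketch, pairs, tmpPerRow and eFor are hypothetical; the T values are derived by hand from the question's rows.

```scala
object CumulativeSketch {
  // (A, T) pairs for the question's rows, where T = max(B, C) + D.
  val pairs: Vector[(Int, Int)] = Vector((100, 25), (0, 15), (0, 10), (0, 10))

  // For row n, tmp holds the first n pairs (the growing window frame).
  val tmpPerRow: Vector[Vector[(Int, Int)]] =
    (1 to pairs.length).map(n => pairs.take(n)).toVector

  // Mirrors aggregate(transform(tmp, ...), 0, (acc, x) -> acc + x):
  // the first element contributes A - T, every later element contributes -T.
  def eFor(tmp: Vector[(Int, Int)]): Int =
    tmp.zipWithIndex
      .map { case ((a, t), i) => if (i == 0) a - t else -t }
      .foldLeft(0)(_ + _)

  val es: Vector[Int] = tmpPerRow.map(eFor)
  // es is Vector(75, 60, 50, 40)
}
```

Because the array is rebuilt for every row, this approach does work that grows with the square of the row count, which is fine for small inputs like the sample.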

As blackbishop said, you can't use the lag function to retrieve a column value that changes from row to row. Since you're using the Scala API, you can develop your own user-defined aggregate function instead.

First, create the following case classes, which represent the row you're currently reading and your aggregator's buffer:

case class InputRow(A: Integer, B: Integer, C: Integer, D: Integer)

case class Buffer(var E: Integer, var A: Integer)

Then you use them to define your RecursiveAggregator custom aggregator:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

object RecursiveAggregator extends Aggregator[InputRow, Buffer, Buffer] {
  override def zero: Buffer = Buffer(null, null)

  override def reduce(buffer: Buffer, currentRow: InputRow): Buffer = {
    buffer.A = if (buffer.E == null) currentRow.A else buffer.E
    buffer.E = buffer.A - (math.max(currentRow.B, currentRow.C) + currentRow.D)
    buffer
  }

  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    throw new NotImplementedError("should be used only over ordered window")
  }

  override def finish(reduction: Buffer): Buffer = reduction

  override def bufferEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]

  override def outputEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
}

Finally, you turn your RecursiveAggregator into a user-defined aggregate function and apply it to your input DataFrame:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val recursiveAggregator = udaf(RecursiveAggregator)

val window = Window.orderBy("Id")

val result = input
  .withColumn("computed", recursiveAggregator(col("A"), col("B"), col("C"), col("D")).over(window))
  .select("Id", "computed.A", "B", "C", "D", "computed.E")

With the DataFrame from your question as input, you get the following result:

+---+---+---+---+---+---+
|id |A  |B  |C  |D  |E  |
+---+---+---+---+---+---+
|1  |100|10 |20 |5  |75 |
|2  |75 |5  |10 |5  |60 |
|3  |60 |7  |2  |3  |50 |
|4  |50 |1  |3  |7  |40 |
+---+---+---+---+---+---+


 