Cumulative sum from the beginning of a stream in Spark

I have to compute a cumulative sum on a value column by group from the beginning of the time series with a daily output.

If I do it as a batch job, it looks something like this:

val columns = Seq("timestamp", "group", "value")
val data = List(
  (Instant.parse("2020-01-01T00:00:00Z"), "Group1",  0),
  (Instant.parse("2020-01-01T00:00:00Z"), "Group2",  0),
  (Instant.parse("2020-01-01T12:00:00Z"), "Group1",  1),
  (Instant.parse("2020-01-01T12:00:00Z"), "Group2", -1),
  (Instant.parse("2020-01-02T00:00:00Z"), "Group1",  2),
  (Instant.parse("2020-01-02T00:00:00Z"), "Group2", -2),
  (Instant.parse("2020-01-02T12:00:00Z"), "Group1",  3),
  (Instant.parse("2020-01-02T12:00:00Z"), "Group2", -3)
)

val df = spark
  .createDataFrame(data)
  .toDF(columns: _*)

// defines a window from the beginning by `group`
val event_window = Window
  .partitionBy($"group")
  .orderBy($"timestamp")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val computed_df = df
  .withColumn(
    "cumsum",
    functions.sum("value").over(event_window) // apply the aggregation on a window from the beginning
  )
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.last("cumsum").as("cumsum_by_day")) // display the last value for each day

computed_df.show(truncate = false)

The output is:

|window                                    |group |cumsum_by_day|
|{2020-01-01 01:00:00, 2020-01-02 01:00:00}|Group1| 1           |
|{2020-01-02 01:00:00, 2020-01-03 01:00:00}|Group1| 6           |
|{2020-01-01 01:00:00, 2020-01-02 01:00:00}|Group2|-1           |
|{2020-01-02 01:00:00, 2020-01-03 01:00:00}|Group2|-6           |

The result is perfectly fine.
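
To make the intended semantics explicit without any Spark dependency, here is a minimal sketch in plain Scala collections. The names `BatchCumsumSketch`, `Event` and `cumsumByDay` are illustrative only, and a "day" is assumed to be a UTC calendar day.

```scala
import java.time.Instant
import java.time.temporal.ChronoUnit

object BatchCumsumSketch {
  final case class Event(timestamp: Instant, group: String, value: Int)

  // For each group: running sum over events ordered by timestamp,
  // then keep the last running value seen on each (assumed UTC) day.
  def cumsumByDay(events: Seq[Event]): Map[(String, Instant), Int] =
    events
      .groupBy(_.group)
      .toSeq
      .flatMap { case (group, groupEvents) =>
        val sorted = groupEvents.sortBy(_.timestamp.toEpochMilli)
        val running = sorted.scanLeft(0)(_ + _.value).tail // cumulative sums
        sorted.zip(running).map { case (e, sum) =>
          (group, e.timestamp.truncatedTo(ChronoUnit.DAYS)) -> sum
        }
      }
      .toMap // for a repeated (group, day) key, the later entry wins

  // Same rows as the sample data in the question
  val sample: Seq[Event] = Seq(
    Event(Instant.parse("2020-01-01T00:00:00Z"), "Group1", 0),
    Event(Instant.parse("2020-01-01T00:00:00Z"), "Group2", 0),
    Event(Instant.parse("2020-01-01T12:00:00Z"), "Group1", 1),
    Event(Instant.parse("2020-01-01T12:00:00Z"), "Group2", -1),
    Event(Instant.parse("2020-01-02T00:00:00Z"), "Group1", 2),
    Event(Instant.parse("2020-01-02T00:00:00Z"), "Group2", -2),
    Event(Instant.parse("2020-01-02T12:00:00Z"), "Group1", 3),
    Event(Instant.parse("2020-01-02T12:00:00Z"), "Group2", -3)
  )
}
```

Calling `BatchCumsumSketch.cumsumByDay(BatchCumsumSketch.sample)` gives 1 and 6 for Group1 and -1 and -6 for Group2, matching the table above.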

However, in my case the data source is not an existing dataset but a stream, and I haven't found a way to apply the aggregation from the beginning of the stream rather than over a sliding window.

The closest I can get is:

// MemoryStream to reproduce locally the issue
implicit val sqlCtx: SQLContext = spark.sqlContext
val memoryStream = MemoryStream[(Instant, String, Int)]
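val df = memoryStream
  .toDF()
  .toDF(columns: _*)

val computed_df = df
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.sum("value").as("agg"), functions.last("timestamp").as("ts"))

// console sink in complete mode (assumed here for local testing)
computed_df.writeStream
  .option("truncate", value = false)
  .format("console")
  .outputMode("complete")
  .start()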

It produces an aggregation for each day but not from the beginning of the stream.
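
The difference between the two behaviours can be illustrated with a small plain-Scala sketch (no Spark involved). The object and function names are hypothetical, and days are reduced to integer indices for brevity.

```scala
object DailyVersusCumulative {
  // (day index, value) pairs for one group, already ordered by time
  val group1: Seq[(Int, Int)] = Seq((1, 0), (1, 1), (2, 2), (2, 3))

  // What a per-day tumbling window computes: an independent sum per day
  def dailySums(rows: Seq[(Int, Int)]): Seq[(Int, Int)] =
    rows
      .groupBy(_._1)
      .map { case (day, vs) => day -> vs.map(_._2).sum }
      .toSeq
      .sortBy(_._1)

  // What is wanted: the sum of everything seen up to the end of each day
  def cumulativeByDay(rows: Seq[(Int, Int)]): Seq[(Int, Int)] = {
    val daily = dailySums(rows)
    daily.map(_._1).zip(daily.map(_._2).scanLeft(0)(_ + _).tail)
  }
}
```

For `group1`, `dailySums` yields 1 for day 1 and 5 for day 2, whereas `cumulativeByDay` yields 1 and 6. The second result is the one the streaming query should produce.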

If I try to add something like .over(event_window) (like in batch), it compiles but fails at runtime.

How can we apply an aggregation function from the beginning of a stream?

Here is a GitHub repository with all the context needed to run this code.

I didn't find any solution using the high-level functions. For example, it is not possible to add another groupBy on top of the main aggregation agg(functions.sum('value).as("agg"), functions.last('timestamp).as("ts")) to get the daily report.

After many experiments, I switched to the low-level functions. The most versatile one seems to be flatMapGroupsWithState.

// same `events` DataFrame as before

// Accumulate value by group and report every day
val computed_df = events
  .withWatermark("timestamp", "0 second") // watermarking required to use GroupStateTimeout
  .as[(Instant, String, Int)]
  .groupByKey(event => event._2)
  .flatMapGroupsWithState[IntermediateState, AggResult](
    OutputMode.Append(),
    GroupStateTimeout.EventTimeTimeout() // needed for state.setTimeoutTimestamp
  )(processEventGroup)

Then it returns:

Batch: 0
|day_start          |group |cumsum_by_day|
|2020-01-01 01:00:00|Group2|-1           |
|2020-01-02 01:00:00|Group2|-6           |
|2020-01-01 01:00:00|Group1|1            |
|2020-01-02 01:00:00|Group1|6            |

Batch: 1
|day_start          |group |cumsum_by_day|
|2020-01-03 01:00:00|Group2|-15          |
|2020-01-03 01:00:00|Group1|15           |

processEventGroup is the key function. It contains all the technical logic: the cumulative aggregation and the output of one report per day.

def processEventGroup(
    group: String,
    events: Iterator[(Instant, String, Int)],
    state: GroupState[IntermediateState]
): Iterator[AggResult] = {
  def mergeState(events: List[Event]): Iterator[AggResult] = {
    // Initialize the aggregation from the previous state, or start a new one
    var (acc_value, acc_timestamp) = state.getOption
      .map(s => (s.agg_value, s.last_timestamp))
      .getOrElse((0, Instant.EPOCH))

    val agg_results = events.flatMap { e =>
      // create a daily report if the new event occurs on another day
      val intermediate_day_result =
        if ( // not same day
          acc_timestamp != Instant.EPOCH &&
          truncateDay(e.timestamp).isAfter(truncateDay(acc_timestamp))
        ) {
          Seq(AggResult(truncateDay(acc_timestamp), group, acc_value))
        } else {
          Seq.empty
        }
      // apply the aggregation as usual (`sum` on value, `last` on timestamp)
      acc_value += e.value
      acc_timestamp = e.timestamp
      intermediate_day_result
    }

    // if a timeout occurs before the next events arrive for the same group,
    // a daily report will be generated
    state.setTimeoutTimestamp(state.getCurrentWatermarkMs, "1 day")
    // save the current aggregated value in the state store
    state.update(IntermediateState(acc_timestamp, group, acc_value))
    agg_results.iterator
  }

  if (state.hasTimedOut && events.isEmpty) {
    // generate a daily report on timeout
    state.getOption
      .map(agg_result =>
        AggResult(truncateDay(agg_result.last_timestamp), group, agg_result.agg_value)
      )
      .iterator
  } else {
    // a list of daily reports may be generated while processing the new events
    mergeState(events.map { case (timestamp, group, value) =>
      Event(timestamp, group, value)
    }.toList)
  }
}

processEventGroup is called in each micro-batch, once per group that has new events or whose timeout has expired. The state is managed by GroupState (the state only needs to be serializable).
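
The way state is carried from one micro-batch to the next can be modelled without Spark. The following is only a sketch under stated assumptions: `step`, `State` and `Report` are hypothetical names, the `Option[State]` argument stands in for GroupState, a "day" is a UTC calendar day, and the timeout path is deliberately left out.

```scala
import java.time.Instant
import java.time.temporal.ChronoUnit

object StatefulCumsumSketch {
  final case class Event(timestamp: Instant, group: String, value: Int)
  final case class State(lastTimestamp: Instant, sum: Int)
  final case class Report(dayStart: Instant, group: String, cumsum: Int)

  private def day(ts: Instant): Instant = ts.truncatedTo(ChronoUnit.DAYS)

  // Process the events of one group for one micro-batch. Returns the new
  // state plus one report for every day that a later event has closed.
  def step(
      group: String,
      previous: Option[State],
      events: Seq[Event]
  ): (Option[State], Seq[Report]) =
    events
      .sortBy(_.timestamp.toEpochMilli)
      .foldLeft[(Option[State], Seq[Report])]((previous, Seq.empty)) {
        case ((state, reports), e) =>
          // the previous day is reported once an event of a later day shows up
          val closed = state.collect {
            case s if day(e.timestamp).isAfter(day(s.lastTimestamp)) =>
              Report(day(s.lastTimestamp), group, s.sum)
          }
          // `sum` on value, `last` on timestamp
          val next = State(e.timestamp, state.map(_.sum).getOrElse(0) + e.value)
          (Some(next), reports ++ closed.toSeq)
      }
}
```

Feeding the four Group1 events of the sample data as a first batch returns a state with sum 6 and a single report for 2020-01-01 with value 1. Passing that state into a second call with one event of value 9 on 2020-01-03 returns a report for 2020-01-02 with value 6 and a new sum of 15. In the real job the last open day is emitted by the timeout branch instead.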

For completeness, here are the missing elements:

def truncateDay(ts: Instant): Instant = {
  ts.truncatedTo(ChronoUnit.DAYS)
}

case class Event(timestamp: Instant, group: String, value: Int)

case class IntermediateState(last_timestamp: Instant, group: String, agg_value: Int)

case class AggResult(day_start: Instant, group: String, cumsum_by_day: Int)
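
As a standalone check of how a day-truncation helper behaves around day boundaries, here is a small sketch. It assumes that a "day" means a UTC calendar day, which would be consistent with the 01:00:00 day starts in the output above if the session time zone is UTC+1. The object name is illustrative.

```scala
import java.time.Instant
import java.time.temporal.ChronoUnit

object TruncateDaySketch {
  // Assumption: day boundaries are taken at UTC midnight
  def truncateDay(ts: Instant): Instant = ts.truncatedTo(ChronoUnit.DAYS)

  // Two instants belong to the same daily report if they truncate to the same day
  def sameDay(a: Instant, b: Instant): Boolean = truncateDay(a) == truncateDay(b)
}
```

With this definition, 2020-01-02T12:00:00Z truncates to 2020-01-02T00:00:00Z, and an event at 2020-01-01T23:59:59Z still belongs to the report for 2020-01-01.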

(code available here)

