
Cumulative sum from the beginning of a stream in Spark

I have to compute a cumulative sum on a value column, per group, from the beginning of the time series, with one output row per day.

With a batch DataFrame, it looks something like this:

// imports needed to run this snippet (assuming a SparkSession named `spark`)
import java.time.Instant
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.{col, window}
import spark.implicits._

val columns = Seq("timestamp", "group", "value")
val data = List(
  (Instant.parse("2020-01-01T00:00:00Z"), "Group1",  0),
  (Instant.parse("2020-01-01T00:00:00Z"), "Group2",  0),
  (Instant.parse("2020-01-01T12:00:00Z"), "Group1",  1),
  (Instant.parse("2020-01-01T12:00:00Z"), "Group2", -1),
  (Instant.parse("2020-01-02T00:00:00Z"), "Group1",  2),
  (Instant.parse("2020-01-02T00:00:00Z"), "Group2", -2),
  (Instant.parse("2020-01-02T12:00:00Z"), "Group1",  3),
  (Instant.parse("2020-01-02T12:00:00Z"), "Group2", -3),
)

val df = spark
  .createDataFrame(data)
  .toDF(columns: _*)

// defines a window from the beginning by `group`
val event_window = Window
  .partitionBy(col("group"))
  .orderBy(col("timestamp"))
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val computed_df = df
  .withColumn(
    "cumsum",
    functions
      .sum('value)
      .over(event_window) // apply the aggregation on a window from the beginning
  )
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.last("cumsum").as("cumsum_by_day")) // display the last value for each day

computed_df.show(truncate = false)

and the output is

+------------------------------------------+------+-------------+
|window                                    |group |cumsum_by_day|
+------------------------------------------+------+-------------+
|{2020-01-01 01:00:00, 2020-01-02 01:00:00}|Group1| 1           |
|{2020-01-02 01:00:00, 2020-01-03 01:00:00}|Group1| 6           |
|{2020-01-01 01:00:00, 2020-01-02 01:00:00}|Group2|-1           |
|{2020-01-02 01:00:00, 2020-01-03 01:00:00}|Group2|-6           |
+------------------------------------------+------+-------------+

The result is perfectly fine.

However, in my case the data source is not an existing dataset but a stream, and I couldn't find any way to apply the aggregation from the beginning of the stream rather than over a sliding window.

The closest I can get is:

// MemoryStream to reproduce the issue locally
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream
import spark.implicits._ // encoders for MemoryStream[(Instant, String, Int)]

implicit val sqlCtx: SQLContext = spark.sqlContext
val memoryStream = MemoryStream[(Instant, String, Int)]
memoryStream.addData(data)
val df = memoryStream
  .toDF()
  .toDF(columns: _*)

val computed_df = df
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.sum('value).as("agg"))

computed_df.writeStream
  .option("truncate", value = false)
  .format("console")
  .outputMode("complete")
  .start()
  .processAllAvailable()

It produces an aggregation for each day but not from the beginning of the stream.

If I try to add something like .over(event_window) (as in the batch version), it compiles but fails at runtime.
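
For illustration, a rough sketch of that attempt (the failing_df name is mine); the analyzer rejects it when the query starts, with an error along the lines of "Non-time-based windows are not supported on streaming DataFrames/Datasets":

// Attempt: reuse the batch window function on the stream.
// This compiles, but Spark rejects it once the streaming query starts.
val failing_df = df
  .withColumn("cumsum", functions.sum('value).over(event_window))
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.last("cumsum").as("cumsum_by_day"))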

How can we apply an aggregation function from the beginning of a stream?

Here is a GitHub repository with all the context needed to run that code.

I didn't find any solution using the high-level functions. For example, it is not possible to stack another groupBy on top of the main streaming aggregation agg(functions.sum('value).as("agg"), functions.last('timestamp).as("ts")) to get the daily report, because Spark rejects multiple streaming aggregations.
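
Concretely, something like the following sketch is rejected at analysis time (the daily_df/chained_df names are hypothetical; the error is along the lines of "Multiple streaming aggregations are not supported with streaming DataFrames/Datasets"):

// First streaming aggregation: sum per day and group.
val daily_df = df
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(
    functions.sum('value).as("agg"),
    functions.last('timestamp).as("ts")
  )

// Trying to aggregate again over the first result fails on a stream.
val chained_df = daily_df
  .groupBy($"group")
  .agg(functions.sum($"agg").as("cumsum"))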

After many experiments, I switched to the lower-level functions. The most versatile one seems to be flatMapGroupsWithState.
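
For reference, its signature on KeyValueGroupedDataset[K, V] looks like this (simplified from the Spark API):

def flatMapGroupsWithState[S: Encoder, U: Encoder](
    outputMode: OutputMode,
    timeoutConf: GroupStateTimeout)(
    func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]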

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// `events` is the same streaming DataFrame as `df` above

// Accumulate value by group and report every day
val computed_df = events
  .withWatermark("timestamp", "0 second") // watermarking required to use GroupStateTimeout
  .as[(Instant, String, Int)]
  .groupByKey(event => event._2)
  .flatMapGroupsWithState[IntermediateState, AggResult](
    OutputMode.Append(),
    GroupStateTimeout.EventTimeTimeout
  )(processEventGroup)

Then it returns:

-------------------------------------------
Batch: 0
-------------------------------------------
+-------------------+------+-------------+
|day_start          |group |cumsum_by_day|
+-------------------+------+-------------+
|2020-01-01 01:00:00|Group2|-1           |
|2020-01-02 01:00:00|Group2|-6           |
|2020-01-01 01:00:00|Group1|1            |
|2020-01-02 01:00:00|Group1|6            |
+-------------------+------+-------------+

-------------------------------------------
Batch: 1
-------------------------------------------
+-------------------+------+-------------+
|day_start          |group |cumsum_by_day|
+-------------------+------+-------------+
|2020-01-03 01:00:00|Group2|-15          |
|2020-01-03 01:00:00|Group1|15           |
+-------------------+------+-------------+

processEventGroup is the key function: it contains all the technical work, the cumulative aggregation and the report emitted at the end of each day.

def processEventGroup(
    group: String,
    events: Iterator[(Instant, String, Int)],
    state: GroupState[IntermediateState]
) = {
  def mergeState(events: List[Event]): Iterator[AggResult] = {
    // Initialize the aggregation from the previous state, or from scratch for a new group
    var (acc_value, acc_timestamp) = state.getOption
      .map(s => (s.agg_value, s.last_timestamp))
      .getOrElse((0, Instant.EPOCH))

    val agg_results = events.flatMap { e =>
      // emit a daily report if the new event occurs on a different day
      val intermediate_day_result =
        if ( // not same day
          acc_timestamp != Instant.EPOCH &&
          truncateDay(e.timestamp) > truncateDay(acc_timestamp)
        ) {
          Seq(AggResult(truncateDay(acc_timestamp), group, acc_value))
        } else {
          Seq.empty
        }
      // apply the aggregation as usual (`sum` on value, `last` on timestamp)
      acc_value += e.value
      acc_timestamp = e.timestamp
      intermediate_day_result
    }

    // if a timeout occurs before the next events arrive in this group,
    // a daily report will be generated
    state.setTimeoutTimestamp(state.getCurrentWatermarkMs, "1 day")
    // save the current aggregated value in the state store
    state.update(IntermediateState(acc_timestamp, group, acc_value))
    agg_results.iterator
  }

  if (state.hasTimedOut && events.isEmpty) {
    // generate a daily report on timeout
    state.getOption
      .map(agg_result =>
        AggResult(truncateDay(agg_result.last_timestamp), group, agg_result.agg_value)
      )
      .iterator
  } else {
    // a list of daily reports may be generated while processing the new events
    mergeState(events.map { case (timestamp, group, value) =>
      Event(timestamp, group, value)
    }.toList)
  }
}

processEventGroup is called once per group at each batch. The state is managed by GroupState (the state class just needs to be serializable).
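
To run it locally, the resulting stream can be wired to the console sink like the earlier attempt. A minimal sketch, assuming the same MemoryStream setup as above (moreData is a placeholder for whatever later events you feed in; new events advance the watermark and fire the day timeouts):

val query = computed_df.writeStream
  .option("truncate", value = false)
  .format("console")
  .outputMode("append") // must match OutputMode.Append() above
  .start()

memoryStream.addData(data) // first chunk of events -> Batch 0
query.processAllAvailable()
memoryStream.addData(moreData) // hypothetical follow-up events -> Batch 1
query.processAllAvailable()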

For completeness, here are the missing elements:

import java.time.temporal.ChronoUnit

def truncateDay(ts: Instant): Instant = {
  ts.truncatedTo(ChronoUnit.DAYS)
}

case class Event(timestamp: Instant, group: String, value: Int)

case class IntermediateState(last_timestamp: Instant, group: String, agg_value: Int)

case class AggResult(day_start: Instant, group: String, cumsum_by_day: Int)

(code available here)
