I have to compute a cumulative sum on a value column, by group, from the beginning of the time series, with a daily output. If I do it with a batch, it should be something like this:
import java.time.Instant

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.{col, window}
import spark.implicits._ // enables the $"..." and 'symbol column syntax

val columns = Seq("timestamp", "group", "value")
val data = List(
  (Instant.parse("2020-01-01T00:00:00Z"), "Group1", 0),
  (Instant.parse("2020-01-01T00:00:00Z"), "Group2", 0),
  (Instant.parse("2020-01-01T12:00:00Z"), "Group1", 1),
  (Instant.parse("2020-01-01T12:00:00Z"), "Group2", -1),
  (Instant.parse("2020-01-02T00:00:00Z"), "Group1", 2),
  (Instant.parse("2020-01-02T00:00:00Z"), "Group2", -2),
  (Instant.parse("2020-01-02T12:00:00Z"), "Group1", 3),
  (Instant.parse("2020-01-02T12:00:00Z"), "Group2", -3)
)
val df = spark
  .createDataFrame(data)
  .toDF(columns: _*)
// defines a window from the beginning, by `group`
val event_window = Window
  .partitionBy(col("group"))
  .orderBy(col("timestamp"))
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val computed_df = df
  .withColumn(
    "cumsum",
    functions
      .sum('value)
      .over(event_window) // apply the aggregation on a window from the beginning
  )
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.last("cumsum").as("cumsum_by_day")) // keep the last value for each day

computed_df.show(truncate = false)
and the output is
+------------------------------------------+------+-------------+
|window                                    |group |cumsum_by_day|
+------------------------------------------+------+-------------+
|{2020-01-01 01:00:00, 2020-01-02 01:00:00}|Group1|1            |
|{2020-01-02 01:00:00, 2020-01-03 01:00:00}|Group1|6            |
|{2020-01-01 01:00:00, 2020-01-02 01:00:00}|Group2|-1           |
|{2020-01-02 01:00:00, 2020-01-03 01:00:00}|Group2|-6           |
+------------------------------------------+------+-------------+
The result is perfectly fine.
However, in my case the data source is not a static dataset but a stream, and I didn't find any way to apply the aggregation from the beginning of the stream rather than on a sliding window.
The closest code I managed to write is:
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream

// MemoryStream to reproduce the issue locally
implicit val sqlCtx: SQLContext = spark.sqlContext
val memoryStream = MemoryStream[(Instant, String, Int)]
memoryStream.addData(data)

val df = memoryStream
  .toDF()
  .toDF(columns: _*)

val computed_df = df
  .groupBy(window($"timestamp", "1 day"), $"group")
  .agg(functions.sum('value).as("agg"))

computed_df.writeStream
  .option("truncate", value = false)
  .format("console")
  .outputMode("complete")
  .start()
  .processAllAvailable()
It produces an aggregation for each day, but not from the beginning of the stream. If I try to add something like .over(event_window) (as in the batch version), it compiles but fails at runtime.
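For reference, this is the kind of attempt that fails. The exact snippet and error are not in the original post; this is a sketch, and Spark's analyzer rejects it with a message along the lines of "Non-time-based windows are not supported on streaming DataFrames":
// assumed attempt: reuse the batch window function on the streaming DataFrame;
// it compiles, but the query is rejected at runtime because non-time-based
// window functions are not supported on streaming DataFrames
val failing_df = df.withColumn(
  "cumsum",
  functions.sum('value).over(event_window)
)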
How can we apply an aggregation function from the beginning of a stream?
Here is a GitHub repository with all the context to run that code.
I didn't find any solution using the high-level functions. For example, it is not possible to chain another groupBy after the main aggregation agg(functions.sum('value).as("agg"), functions.last('timestamp).as("ts")) to get the daily report.
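A sketch of the kind of chained query that gets rejected; the exact code is not in the original post, so the shape below is an assumption:
// assumed shape of the rejected attempt: a second streaming aggregation
// chained after the first one; Spark does not allow multiple aggregations
// on a streaming Dataset, so this fails at analysis time
val rejected_df = df
  .groupBy($"group")
  .agg(
    functions.sum('value).as("agg"),
    functions.last('timestamp).as("ts")
  )
  .groupBy(window($"ts", "1 day"), $"group")
  .agg(functions.last("agg").as("cumsum_by_day"))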
After many experiments, I switched to the low-level functions. The most versatile one seems to be flatMapGroupsWithState.
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// `events` is the same streaming DataFrame as `df` above
// accumulate the value by group and report every day
val computed_df = events
  .withWatermark("timestamp", "0 second") // a watermark is required to use an event-time GroupStateTimeout
  .as[(Instant, String, Int)]
  .groupByKey(event => event._2) // key by group
  .flatMapGroupsWithState[IntermediateState, AggResult](
    OutputMode.Append(),
    GroupStateTimeout.EventTimeTimeout
  )(processEventGroup)
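The sink wiring is not shown here; a minimal sketch, mirroring the console sink from the earlier attempt but in append mode to match the declared OutputMode:
computed_df.writeStream
  .option("truncate", value = false)
  .format("console")
  .outputMode("append") // must match the OutputMode passed to flatMapGroupsWithState
  .start()
  .processAllAvailable()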
Then it returns:
-------------------------------------------
Batch: 0
-------------------------------------------
+-------------------+------+-------------+
|day_start          |group |cumsum_by_day|
+-------------------+------+-------------+
|2020-01-01 01:00:00|Group2|-1           |
|2020-01-02 01:00:00|Group2|-6           |
|2020-01-01 01:00:00|Group1|1            |
|2020-01-02 01:00:00|Group1|6            |
+-------------------+------+-------------+
-------------------------------------------
Batch: 1
-------------------------------------------
+-------------------+------+-------------+
|day_start          |group |cumsum_by_day|
+-------------------+------+-------------+
|2020-01-03 01:00:00|Group2|-15          |
|2020-01-03 01:00:00|Group1|15           |
+-------------------+------+-------------+
processEventGroup is the key function; it contains all the technical parts: the cumulative aggregation, and the emission of a report after each day.
def processEventGroup(
    group: String,
    events: Iterator[(Instant, String, Int)],
    state: GroupState[IntermediateState]
): Iterator[AggResult] = {
  def mergeState(events: List[Event]): Iterator[AggResult] = {
    // initialize the aggregation from the previous state, or from scratch for a new group
    var (acc_value, acc_timestamp) = state.getOption
      .map(s => (s.agg_value, s.last_timestamp))
      .getOrElse((0, Instant.EPOCH))
    val agg_results = events.flatMap { e =>
      // emit a daily report if the new event occurs on a later day
      val intermediate_day_result =
        if ( // not the same day
          acc_timestamp != Instant.EPOCH &&
          truncateDay(e.timestamp) > truncateDay(acc_timestamp)
        ) {
          Seq(AggResult(truncateDay(acc_timestamp), group, acc_value))
        } else {
          Seq.empty
        }
      // apply the aggregation as usual (`sum` on value, `last` on timestamp)
      acc_value += e.value
      acc_timestamp = e.timestamp
      intermediate_day_result
    }
    // if a timeout occurs before the next events of the same group,
    // a daily report will be generated
    state.setTimeoutTimestamp(state.getCurrentWatermarkMs, "1 day")
    // save the current aggregated value in the state store
    state.update(IntermediateState(acc_timestamp, group, acc_value))
    agg_results.iterator
  }

  if (state.hasTimedOut && events.isEmpty) {
    // generate a daily report on timeout
    state.getOption
      .map(agg_result =>
        AggResult(truncateDay(agg_result.last_timestamp), group, agg_result.agg_value)
      )
      .iterator
  } else {
    // several daily reports may be generated while processing the new events
    mergeState(events.map { case (timestamp, group, value) =>
      Event(timestamp, group, value)
    }.toList)
  }
}
processEventGroup is called at each micro-batch, once for every group that received new events or whose timeout fired. The state is managed through GroupState (the state just needs to be serializable).
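To relate this to the console output above, here is a rough timeline for Group1. The day-3 events (values 4 and 5) are an assumption: they are not shown in the post, but they are consistent with Batch 0 already containing the day-2 report and with the reported sums of 15 and -15:
// Assumed feed for Group1: day 1 -> (0, 1), day 2 -> (2, 3), day 3 -> (4, 5)
// Batch 0:
//   first day-2 event arrives -> day-1 report emitted (cumsum = 0 + 1 = 1)
//   first day-3 event arrives -> day-2 report emitted (cumsum = 1 + 2 + 3 = 6)
//   state afterwards: acc_value = 15, last_timestamp = 2020-01-03T12:00Z
//   the timeout was armed from the then-current watermark; the watermark
//   then advances past it
// Batch 1:
//   no new events, but the watermark has passed the timeout
//   -> the timed-out branch emits the day-3 report (cumsum = 6 + 4 + 5 = 15)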
For completeness, here are the missing elements:
import java.time.temporal.ChronoUnit

def truncateDay(ts: Instant): Instant = {
  ts.truncatedTo(ChronoUnit.DAYS)
}

case class Event(timestamp: Instant, group: String, value: Int)
case class IntermediateState(last_timestamp: Instant, group: String, agg_value: Int)
case class AggResult(day_start: Instant, group: String, cumsum_by_day: Int)
(code available here)