
How to convert fold() to an AggregateFunction in Flink for WindowFunction?

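I am trying to port an old Flink version of the Yahoo streaming benchmark to a newer Flink release by removing its use of deprecated classes.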

I'm currently stuck converting the deprecated fold() to aggregate(). I can't work out how to map fold's existing arguments onto aggregate's parameters.

// Old version using fold
val windowedCounts = windowedEvents.fold(new WindowedCount(null, "", 0, new java.sql.Timestamp(0L)),
  (acc: WindowedCount, r: (String, String, Timestamp)) => {
    val lastUpdate = if (acc.lastUpdate.getTime < r._3.getTime) r._3 else acc.lastUpdate
    acc.count += 1
    acc.lastUpdate = lastUpdate
    acc
  },
  (key: Tuple, window: TimeWindow, input: Iterable[WindowedCount], out: Collector[WindowedCount]) => {
    val windowedCount = input.iterator.next()
    println(windowedCount.lastUpdate)
    out.collect(new WindowedCount(new java.sql.Timestamp(window.getStart), key.getField(0), windowedCount.count, windowedCount.lastUpdate))
  }
)
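The three arguments of the old fold() call line up with the newer API like this: the initial value becomes createAccumulator(), the fold function becomes add() (with the accumulator and value arguments swapped), and the window function becomes a ProcessWindowFunction passed alongside the AggregateFunction. AggregateFunction additionally requires getResult() and merge(), which fold has no counterpart for. The Flink-free Scala script below is a minimal sketch of that mapping; `AggregateLike`, `fromFold`, and `runWindow` are hypothetical names for illustration, not Flink API, and `AggregateLike` only mirrors the method shape of Flink's DataStream AggregateFunction.

```scala
// Hypothetical stand-in, NOT Flink's class: mirrors the four methods of the
// DataStream AggregateFunction[IN, ACC, OUT].
trait AggregateLike[IN, ACC, OUT] {
  def createAccumulator(): ACC
  def add(value: IN, acc: ACC): ACC
  def getResult(acc: ACC): OUT
  def merge(a: ACC, b: ACC): ACC
}

// Builds an aggregate from fold-style pieces (hypothetical helper).
// The initial value is taken as a factory so each window gets a fresh accumulator;
// getResult is the identity, and merge has to be supplied by hand.
def fromFold[IN, ACC](initial: () => ACC,
                      foldFn: (ACC, IN) => ACC,
                      mergeFn: (ACC, ACC) => ACC): AggregateLike[IN, ACC, ACC] =
  new AggregateLike[IN, ACC, ACC] {
    def createAccumulator(): ACC = initial()
    def add(value: IN, acc: ACC): ACC = foldFn(acc, value) // fold takes (acc, value); add takes (value, acc)
    def getResult(acc: ACC): ACC = acc
    def merge(a: ACC, b: ACC): ACC = mergeFn(a, b)
  }

// Applies an aggregate to one window's elements, in arrival order (hypothetical helper).
def runWindow[IN, ACC, OUT](agg: AggregateLike[IN, ACC, OUT], elements: Seq[IN]): OUT =
  agg.getResult(elements.foldLeft(agg.createAccumulator())((acc, v) => agg.add(v, acc)))

// A plain counter expressed the old way, then run through the aggregate shape.
val counter = fromFold[String, Long](() => 0L, (acc, _) => acc + 1, _ + _)
println(runWindow(counter, Seq("a", "b", "c"))) // prints 3
```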

// New version using aggregate
val windowedCounts = windowedEvents.aggregate(new CountAggregate)

I'd like to create the CountAggregate class by extending AggregateFunction, along the lines of:

class CountAggregate extends AggregateFunction[(String, String, Timestamp), WindowedCount, Collector[WindowedCount]] {
  override def createAccumulator() = WindowedCount(null, "", 0, new java.sql.Timestamp(0L))

  override def accumulate(acc: WindowedCount, r: (String, String, Timestamp)): WindowedCount = {
    val lastUpdate = if (acc.lastUpdate.getTime < r._3.getTime) r._3 else acc.lastUpdate
    acc.count += 1
    acc.lastUpdate = lastUpdate
    acc
  }

  override def getValue(acc: WindowedCount) = { (key: Tuple, window: TimeWindow, input: Iterable[WindowedCount], out: Collector[WindowedCount]) =>
    val windowedCount = input.iterator.next()
    println(windowedCount.lastUpdate)
    out.collect(new WindowedCount(new java.sql.Timestamp(window.getStart), key.getField(0), windowedCount.count, windowedCount.lastUpdate))
  }
}

Any help rewriting the CountAggregate class would be greatly appreciated.

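You need to pass aggregate() two functions: an AggregateFunction, which counts the events incrementally, and a ProcessWindowFunction, which performs the final step your getValue was attempting. The AggregateFunction never sees the key or the window, so the ProcessWindowFunction receives its pre-aggregated WindowedCount together with the key and window and emits the result with the window start and key filled in. Note also that accumulate() and getValue() belong to the Table API's AggregateFunction; the DataStream aggregate() expects org.apache.flink.api.common.functions.AggregateFunction, whose methods are createAccumulator(), add(), getResult(), and merge():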

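import java.sql.Timestamp

import org.apache.flink.api.common.functions.AggregateFunction
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.streaming.api.scala.function.ProcessWindowFunction
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.Collector

// Incremental counting in CountAggregate; window metadata is added by WindowAggregateFunction
val windowedCounts = windowedEvents.aggregate(
  new CountAggregate(),
  new WindowAggregateFunction())

// Counts events and tracks the latest event timestamp; it never sees the key or the window.
class CountAggregate extends AggregateFunction[(String, String, Timestamp), WindowedCount, WindowedCount] {
  override def createAccumulator(): WindowedCount = WindowedCount(null, "", 0, new Timestamp(0L))

  override def add(value: (String, String, Timestamp), acc: WindowedCount): WindowedCount = {
    val lastUpdate = if (acc.lastUpdate.getTime < value._3.getTime) value._3 else acc.lastUpdate
    WindowedCount(null, "", acc.count + 1, lastUpdate)
  }

  override def getResult(accumulator: WindowedCount): WindowedCount = accumulator

  // Combines two partial counts; used when windows are merged (e.g. session windows)
  override def merge(a: WindowedCount, b: WindowedCount): WindowedCount =
    WindowedCount(null, "", a.count + b.count,
      if (a.lastUpdate.getTime < b.lastUpdate.getTime) b.lastUpdate else a.lastUpdate)
}

// Receives the single pre-aggregated result for each window and fills in window start and key.
class WindowAggregateFunction extends ProcessWindowFunction[WindowedCount, WindowedCount, Tuple, TimeWindow] {
  override def process(key: Tuple, context: Context, elements: Iterable[WindowedCount], out: Collector[WindowedCount]): Unit = {
    val windowedCount = elements.iterator.next() // exactly one element: the value from getResult()
    out.collect(WindowedCount(new Timestamp(context.window.getStart), key.getField(0), windowedCount.count, windowedCount.lastUpdate))
  }
}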
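To make the division of labour concrete, here is a Flink-free sketch, written as a plain Scala script, of what happens for one window: add() runs for each element as it arrives, getResult() runs once when the window fires, and that single value is handed to the window function, which is why `elements.iterator.next()` is safe there. It assumes `WindowedCount` is a case class with four mutable fields in the order used above; only `count` and `lastUpdate` appear in the original code, so the names `timeWindow` and `campaignId`, plus `fireWindow`, `Event`, and the sample events, are hypothetical.

```scala
import java.sql.Timestamp

// Assumed shape of WindowedCount (field names other than count/lastUpdate are hypothetical).
case class WindowedCount(var timeWindow: Timestamp, var campaignId: String,
                         var count: Long, var lastUpdate: Timestamp)

type Event = (String, String, Timestamp)

// Same logic as CountAggregate, written as plain functions.
def createAccumulator(): WindowedCount = WindowedCount(null, "", 0, new Timestamp(0L))

def add(value: Event, acc: WindowedCount): WindowedCount = {
  val lastUpdate = if (acc.lastUpdate.getTime < value._3.getTime) value._3 else acc.lastUpdate
  WindowedCount(null, "", acc.count + 1, lastUpdate)
}

def merge(a: WindowedCount, b: WindowedCount): WindowedCount =
  WindowedCount(null, "", a.count + b.count,
    if (a.lastUpdate.getTime < b.lastUpdate.getTime) b.lastUpdate else a.lastUpdate)

// Same logic as WindowAggregateFunction.process: adds the key and window start,
// which the aggregate step never sees.
def process(key: String, windowStart: Long, elements: Iterable[WindowedCount]): WindowedCount = {
  val wc = elements.iterator.next() // exactly one pre-aggregated element per window
  WindowedCount(new Timestamp(windowStart), key, wc.count, wc.lastUpdate)
}

// Simulates one window firing: incremental add(), getResult() (identity), then process().
def fireWindow(key: String, windowStart: Long, events: Seq[Event]): WindowedCount = {
  val acc = events.foldLeft(createAccumulator())((a, e) => add(e, a))
  process(key, windowStart, Seq(acc))
}

val events: Seq[Event] = Seq(
  ("ad-1", "campaign-1", new Timestamp(1000L)),
  ("ad-2", "campaign-1", new Timestamp(3000L)),
  ("ad-3", "campaign-1", new Timestamp(2000L)))

val result = fireWindow("campaign-1", 0L, events)
println(s"${result.campaignId} count=${result.count} lastUpdate=${result.lastUpdate.getTime}")
// prints: campaign-1 count=3 lastUpdate=3000
```

Splitting the events into two partial accumulators and combining them with merge() gives the same count and lastUpdate as accumulating them in a single pass, which is what merge() should guarantee when windows are combined.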


Disclaimer: technical posts on this site are licensed under CC BY-SA 4.0. If you republish this page, please credit this site's URL or the original post's address. For any questions, contact yoyou2525@163.com.