簡體   English   中英

根據spark中上一行的同一列的值計算值

[英]Calculate value based on value from same column of the previous row in spark

我有一個問題,我必須使用一個公式來計算一列,該公式使用前一行中計算的值。

我無法使用withColumn API 弄清楚。

我需要使用以下公式計算一個新列:

MovingRate = MonthlyRate + (0.7 * MovingRatePrevious)

... 其中MovingRatePrevious是前一行的MovingRate

對於第 1 個月,我有這個值,所以我不需要重新計算它,但我需要該值才能計算后續行。 我需要按類型分區。

這是我的原始數據集:

在此處輸入圖像描述

MovingRate 列中的所需結果:

在此處輸入圖像描述

考慮到每個移動速率是從前一個速率遞歸計算的要求的性質,面向列的 DataFrame API 不會發光,尤其是在數據集很大的情況下。

也就是說,如果數據集不大,一種方法是讓 Spark 通過 UDF 逐行重新計算移動速率,並將窗口分區速率列表作為其輸入:

import org.apache.spark.sql.expressions.Window

val df = Seq(
  (1, "blue", 0.4, Some(0.33)),
  (2, "blue", 0.3, None),
  (3, "blue", 0.7, None),
  (4, "blue", 0.9, None),
  (1, "red", 0.5, Some(0.2)),
  (2, "red", 0.6, None),
  (3, "red", 0.8, None)
).toDF("Month", "Type", "MonthlyRate", "MovingRate")

val win = Window.partitionBy("Type").orderBy("Month").
  rowsBetween(Window.unboundedPreceding, 0)

def movingRate(factor: Double) = udf( (initRate: Double, monthlyRates: Seq[Double]) =>
  monthlyRates.tail.foldLeft(initRate)( _ * factor + _ )
)

df.
  withColumn("MovingRate", when($"Month" === 1, $"MovingRate").otherwise(
    movingRate(0.7)(last($"MovingRate", ignoreNulls=true).over(win), collect_list($"MonthlyRate").over(win))
  )).
  show
// +-----+----+-----------+------------------+
// |Month|Type|MonthlyRate|        MovingRate|
// +-----+----+-----------+------------------+
// |    1| red|        0.5|               0.2|
// |    2| red|        0.6|              0.74|
// |    3| red|        0.8|             1.318|
// |    1|blue|        0.4|              0.33|
// |    2|blue|        0.3|0.5309999999999999|
// |    3|blue|        0.7|1.0716999999999999|
// |    4|blue|        0.9|1.6501899999999998|
// +-----+----+-----------+------------------+

盡管可以使用寡婦函數(請參閱@Leo C 的答案),但我敢打賭,使用groupBy為每個Type聚合一次會更高效。 然后,分解 UDF 的結果以獲取所有行:

val df = Seq(
  (1, "blue", 0.4, Some(0.33)),
  (2, "blue", 0.3, None),
  (3, "blue", 0.7, None),
  (4, "blue", 0.9, None)
)
.toDF("Month", "Type", "MonthlyRate", "MovingRate")

// this udf produces an Seq of Tuple3 (Month, MonthlyRate, MovingRate)
val calcMovingRate = udf((startRate:Double,rates:Seq[Row]) => rates.tail
  .scanLeft((rates.head.getInt(0),startRate,startRate))((acc,curr) => (curr.getInt(0),curr.getDouble(1),acc._3+0.7*curr.getDouble(1)))
)

df
  .groupBy($"Type")
  .agg(
    first($"MovingRate",ignoreNulls=true).as("startRate"),
    collect_list(struct($"Month",$"MonthlyRate")).as("rates")
  )
  .select($"Type",explode(calcMovingRate($"startRate",$"rates")).as("movingRates"))
  .select($"Type",$"movingRates._1".as("Month"),$"movingRates._2".as("MonthlyRate"),$"movingRates._3".as("MovingRate"))
  .show()

給出:

+----+-----+-----------+------------------+
|Type|Month|MonthlyRate|        MovingRate|
+----+-----+-----------+------------------+
|blue|    1|       0.33|              0.33|
|blue|    2|        0.3|              0.54|
|blue|    3|        0.7|              1.03|
|blue|    4|        0.9|1.6600000000000001|
+----+-----+-----------+------------------+

您要做的是計算一個遞歸公式,如下所示:

x[i] = y[i] + 0.7 * x[i-1]

其中x[i]是您在第i行的MovingRatey[i]您在第i行的MonthlyRate

問題是這是一個純粹的順序公式。 每一行都需要前一行的結果,而前一行的結果又需要前一行的結果。 Spark 是一個並行計算引擎,很難使用它來加速無法真正並行化的計算。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM