Calculate value based on value from same column of the previous row in Spark
I have an issue where I have to calculate a column using a formula that uses the value from the calculation done in the previous row. I am unable to figure it out using the withColumn API.
I need to calculate a new column using the formula:
MovingRate = MonthlyRate + (0.7 * MovingRatePrevious)
... where MovingRatePrevious is the MovingRate of the prior row.
For month 1, I have the value, so I do not need to re-calculate it, but I need that value to be able to calculate the subsequent rows. I need to partition by Type.
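To make the recurrence concrete, this is what one partition looks like in plain Scala, using the red rows of the sample data below (an illustrative sketch only, not a Spark solution):
val initRate = 0.2                  // MovingRate given for month 1 of Type "red"
val monthlyRates = Seq(0.6, 0.8)    // MonthlyRate for months 2 and 3
// each step: MovingRate = MonthlyRate + 0.7 * previous MovingRate
val movingRates = monthlyRates.scanLeft(initRate)((prev, m) => m + 0.7 * prev)
// movingRates: Seq(0.2, 0.74, 1.318)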
This is my original dataset:
+-----+----+-----------+----------+
|Month|Type|MonthlyRate|MovingRate|
+-----+----+-----------+----------+
|    1|blue|        0.4|      0.33|
|    2|blue|        0.3|      null|
|    3|blue|        0.7|      null|
|    4|blue|        0.9|      null|
|    1| red|        0.5|       0.2|
|    2| red|        0.6|      null|
|    3| red|        0.8|      null|
+-----+----+-----------+----------+
Desired results in MovingRate column:
+-----+----+-----------+----------+
|Month|Type|MonthlyRate|MovingRate|
+-----+----+-----------+----------+
|    1|blue|        0.4|      0.33|
|    2|blue|        0.3|     0.531|
|    3|blue|        0.7|    1.0717|
|    4|blue|        0.9|   1.65019|
|    1| red|        0.5|       0.2|
|    2| red|        0.6|      0.74|
|    3| red|        0.8|     1.318|
+-----+----+-----------+----------+
Given that each moving rate is recursively computed from the previous rate, the column-oriented DataFrame API won't shine, especially if the dataset is huge.
That said, if the dataset isn't large, one approach would be to make Spark recalculate the moving rates row-wise via a UDF, with a Window-partitioned rate list as its input:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._   // when, last, collect_list, udf
import spark.implicits._                   // $-syntax and toDF (automatic in spark-shell)
val df = Seq(
(1, "blue", 0.4, Some(0.33)),
(2, "blue", 0.3, None),
(3, "blue", 0.7, None),
(4, "blue", 0.9, None),
(1, "red", 0.5, Some(0.2)),
(2, "red", 0.6, None),
(3, "red", 0.8, None)
).toDF("Month", "Type", "MonthlyRate", "MovingRate")
val win = Window.partitionBy("Type").orderBy("Month").
rowsBetween(Window.unboundedPreceding, 0)
// Re-derives the moving rate by folding the formula over the partition's monthly
// rates: next = monthlyRate + factor * previous. The head is the month-1 row,
// whose rate is already given as initRate, so it is skipped via tail.
def movingRate(factor: Double) = udf( (initRate: Double, monthlyRates: Seq[Double]) =>
  monthlyRates.tail.foldLeft(initRate)( _ * factor + _ )
)
df.
withColumn("MovingRate", when($"Month" === 1, $"MovingRate").otherwise(
movingRate(0.7)(last($"MovingRate", ignoreNulls=true).over(win), collect_list($"MonthlyRate").over(win))
)).
show
// +-----+----+-----------+------------------+
// |Month|Type|MonthlyRate| MovingRate|
// +-----+----+-----------+------------------+
// | 1| red| 0.5| 0.2|
// | 2| red| 0.6| 0.74|
// | 3| red| 0.8| 1.318|
// | 1|blue| 0.4| 0.33|
// | 2|blue| 0.3|0.5309999999999999|
// | 3|blue| 0.7|1.0716999999999999|
// | 4|blue| 0.9|1.6501899999999998|
// +-----+----+-----------+------------------+
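To see exactly what the UDF receives per row, you can materialize the two window expressions as plain columns first (a debugging sketch reusing the df and win defined above):
df.
  withColumn("initRate", last($"MovingRate", ignoreNulls=true).over(win)).
  withColumn("ratesSoFar", collect_list($"MonthlyRate").over(win)).
  show(false)
Each row then carries its Type's month-1 rate plus the list of MonthlyRates accumulated up to that row, which is exactly the pair of arguments movingRate folds over.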
Although it's possible to do this with Window Functions (see @Leo C's answer), I bet it's more performant to aggregate once per Type using a groupBy. Then, explode the results of the UDF to get all rows back:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._   // udf, first, collect_list, struct, explode
import spark.implicits._

val df = Seq(
(1, "blue", 0.4, Some(0.33)),
(2, "blue", 0.3, None),
(3, "blue", 0.7, None),
(4, "blue", 0.9, None)
)
.toDF("Month", "Type", "MonthlyRate", "MovingRate")
// this udf produces a Seq of Tuple3 (Month, MonthlyRate, MovingRate), seeding the
// scan with the month-1 row and applying MovingRate = MonthlyRate + 0.7 * previous
val calcMovingRate = udf((startRate: Double, rates: Seq[Row]) => rates.tail
  .scanLeft((rates.head.getInt(0), rates.head.getDouble(1), startRate))(
    (acc, curr) => (curr.getInt(0), curr.getDouble(1), curr.getDouble(1) + 0.7 * acc._3))
)
df
.groupBy($"Type")
.agg(
first($"MovingRate",ignoreNulls=true).as("startRate"),
collect_list(struct($"Month",$"MonthlyRate")).as("rates") // assumes rows arrive in Month order; wrap in sort_array(...) to make that explicit
)
.select($"Type",explode(calcMovingRate($"startRate",$"rates")).as("movingRates"))
.select($"Type",$"movingRates._1".as("Month"),$"movingRates._2".as("MonthlyRate"),$"movingRates._3".as("MovingRate"))
.show()
gives:
+----+-----+-----------+------------------+
|Type|Month|MonthlyRate|        MovingRate|
+----+-----+-----------+------------------+
|blue|    1|        0.4|              0.33|
|blue|    2|        0.3|0.5309999999999999|
|blue|    3|        0.7|1.0716999999999999|
|blue|    4|        0.9|1.6501899999999998|
+----+-----+-----------+------------------+
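If the positional _1/_2/_3 field names bother you, one variant (an untested sketch; RateRow and calcMovingRateNamed are names invented here) is to have the UDF return a case class instead of a Tuple3, so the exploded struct already carries named fields:
case class RateRow(Month: Int, MonthlyRate: Double, MovingRate: Double)

val calcMovingRateNamed = udf((startRate: Double, rates: Seq[Row]) => rates.tail
  .scanLeft(RateRow(rates.head.getInt(0), rates.head.getDouble(1), startRate))(
    (acc, curr) => RateRow(curr.getInt(0), curr.getDouble(1), curr.getDouble(1) + 0.7 * acc.MovingRate))
)
// ...then select $"movingRates.Month" etc. instead of $"movingRates._1"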
What you are trying to do is compute a recursive formula that looks like:
x[i] = y[i] + 0.7 * x[i-1]
where x[i] is your MovingRate at row i and y[i] is your MonthlyRate at row i.
The problem is that this is a purely sequential formula: each row needs the result of the previous one, which in turn needs the result of the one before it. Spark is a parallel computation engine, and it is going to be hard to use it to speed up a calculation that cannot really be parallelized.
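Unrolling the recursion for a few steps makes the dependency chain explicit; every row folds in all of the rows before it:
x[2] = y[2] + 0.7 * x[1]
x[3] = y[3] + 0.7 * y[2] + 0.7^2 * x[1]
x[i] = y[i] + 0.7 * y[i-1] + ... + 0.7^(i-2) * y[2] + 0.7^(i-1) * x[1]
This is also why the window-based answer above has to hand the UDF the full list of prior MonthlyRates for each row.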