简体   繁体   English

在Spark Scala中将当前行中的前一行值求和

[英]sum previous row value in current row in spark scala

I am trying to adjust one of column values based on value in some other data frame. 我正在尝试根据其他数据框中的值调整列值之一。 While doing this, if left over amount more then I need to carry forward to next row and calculate final amount. 这样做时,如果剩余量更多,则需要结转到下一行并计算最终金额。

During this operation, I am not able to hold of previous row left over amount to next row operation. 在此操作期间,我无法保留上一行剩余的金额到下一行操作。 I tried using lag window function and by taking running totals options but those are not working as expected. 我尝试使用滞后窗口功能并采用运行总计选项,但这些选项未按预期工作。

I am work with Scala. 我正在与Scala合作。 Here is input data 这是输入数据

val consumption = sc.parallelize(Seq((20180101, 600), (20180201, 900),(20180301, 400),(20180401, 600),(20180501, 1000),(20180601, 1900),(20180701, 500),(20180801, 100),(20180901, 500))).toDF("Month","Usage")
consumption.show()
+--------+-----+
|   Month|Usage|
+--------+-----+
|20180101|  600|
|20180201|  900|
|20180301|  400|
|20180401|  600|
|20180501| 1000|
|20180601| 1900|
|20180701|  500|
|20180801|  100|
|20180901|  500|
+--------+-----+
val promo = sc.parallelize(Seq((20180101, 1000),(20180201, 100),(20180401, 3000))).toDF("PromoEffectiveMonth","promoAmount")
promo.show()
+-------------------+-----------+
|PromoEffectiveMonth|promoAmount|
+-------------------+-----------+
|           20180101|       1000|
|           20180201|        100|
|           20180401|       3000|
+-------------------+-----------+

expected result: 预期结果:

val finaldf = sc.parallelize(Seq((20180101,600,400,600),(20180201,900,0,400),(20180301,400,0,0),(20180401,600,2400,600),(20180501,1000,1400,1000),(20180601,1900,0,500),(20180701,500,0,0),(20180801,100,0,0),(20180901,500,0,0))).toDF("Month","Usage","LeftOverPromoAmt","AdjustedUsage")
finaldf.show()
+--------+-----+----------------+-------------+
|   Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101|  600|             400|          600|
|20180201|  900|               0|          400|
|20180301|  400|               0|            0|
|20180401|  600|            2400|          600|
|20180501| 1000|            1400|         1000|
|20180601| 1900|               0|          500|
|20180701|  500|               0|            0|
|20180801|  100|               0|            0|
|20180901|  500|               0|            0|
+--------+-----+----------------+-------------+

The logic what I am applying is based on Month and PromoEffective join, need to apply promo amount on consumption usage column till promo amount become zero. 我要应用的逻辑基于“月”和“促销有效联接”,需要在消费使用列上应用促销金额,直到促销金额变为零。

eg: in Jan'18 month, promoamount is 1000, after deducting from usage (600), the left over promo amt is 400 and adj usage is 600. the left over over 400 will be considered for next month and there promo amt for Feb then final promo amount available is 500. here usage is more when compare to usage. 例如:在1月18日,促销金额为1000,从使用量(600)中减去后,剩余的促销金额为400,调整后的使用量为600。剩余的400将会在下个月考虑,2月的促销金额那么最终的促销金额为500。与使用量相比,此处的使用量更大。

So left over promo amount is zero and adjust usage is 400 (900 - 500). 因此剩余的促销金额为零,调整使用量为400(900-500)。

First of all, you need to perform a left_outer join so that for each row you have its corresponding promotion. 首先,您需要执行left_outer连接,以便对每一行都有相应的提升。 The join operation is performed by means of the fields Month and PromoEffectiveMonth from the datasets Consumption and promo , respectively. 分别通过数据集Consumptionpromo MonthPromoEffectiveMonth字段执行PromoEffectiveMonth Note also that I have created a new column, Timestamp . 还要注意,我已经创建了一个新列Timestamp It has been created by using the Spark SQL unix_timestamp function. 它是通过使用Spark SQL unix_timestamp函数创建的。 It will be used to sort the dataset by date. 它将用于按日期对数据集进行排序。

val ds = consumption
    .join(promo, consumption.col("Month") === promo.col("PromoEffectiveMonth"), "left_outer")
    .select("UserID", "Month", "Usage", "promoAmount")
    .withColumn("Timestamp", unix_timestamp($"Month".cast("string"), "yyyyMMdd").cast(TimestampType))

This is the result of these operations. 这是这些操作的结果。

+--------+-----+-----------+-------------------+
|   Month|Usage|promoAmount|          Timestamp|
+--------+-----+-----------+-------------------+
|20180301|  400|       null|2018-03-01 00:00:00|
|20180701|  500|       null|2018-07-01 00:00:00|
|20180901|  500|       null|2018-09-01 00:00:00|
|20180101|  600|       1000|2018-01-01 00:00:00|
|20180801|  100|       null|2018-08-01 00:00:00|
|20180501| 1000|       null|2018-05-01 00:00:00|
|20180201|  900|        100|2018-02-01 00:00:00|
|20180601| 1900|       null|2018-06-01 00:00:00|
|20180401|  600|       3000|2018-04-01 00:00:00|
+--------+-----+-----------+-------------------+

Next, you have to create a Window . 接下来,您必须创建一个Window Window functions are used to perform calculations over a group of records by using some criteria (more info on this here ). 窗口函数用于通过使用某些条件对一组记录进行计算(有关更多信息,请参见此处 )。 In our case, the criteria is to sort each group by Timestamp . 在我们的例子中,标准是按Timestamp对每个组进行排序。

 val window = Window.orderBy("Timestamp")

Okay, now comes the hardest part. 好的,现在是最困难的部分。 You need to create a User Defined Aggregate Function . 您需要创建一个用户定义的聚合函数 In this function, each group will be processed according to a custom operation, and it will enable you to process each row by taking into account the value of the previous one. 在此功能中,将根据自定义操作对每个组进行处理,并使您可以通过考虑上一行的值来处理每一行。

  class CalculatePromos extends UserDefinedAggregateFunction {
    // Input schema for this UserDefinedAggregateFunction
    override def inputSchema: StructType =
      StructType(
        StructField("Usage", LongType) ::
        StructField("promoAmount", LongType) :: Nil)

    // Schema for the parameters that will be used internally to buffer temporary values
    override def bufferSchema: StructType = StructType(
        StructField("AdjustedUsage", LongType) ::
        StructField("LeftOverPromoAmt", LongType) :: Nil
    )

    // The data type returned by this UserDefinedAggregateFunction.
    // In this case, it will return an StructType with two fields: AdjustedUsage and LeftOverPromoAmt
    override def dataType: DataType = StructType(Seq(StructField("AdjustedUsage", LongType), StructField("LeftOverPromoAmt", LongType)))

    // Whether this UDAF is deterministic or not. In this case, it is
    override def deterministic: Boolean = true

    // Initial values for the temporary values declared above
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L
      buffer(1) = 0L
    }

    // In this function, the values associated to the buffer schema are updated
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

      val promoAmount = if(input.isNullAt(1)) 0L else input.getLong(1)
      val leftOverAmount = buffer.getLong(1)
      val usage = input.getLong(0)
      val currentPromo = leftOverAmount + promoAmount

      if(usage < currentPromo) {
        buffer(0) = usage
        buffer(1) = currentPromo - usage
      } else {
        if(currentPromo == 0)
          buffer(0) = 0L
        else
          buffer(0) = usage - currentPromo
        buffer(1) = 0L
      }
    }

    // Function used to merge two objects. In this case, it is not necessary to define this method since
    // the whole logic has been implemented in update
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}

    // It is what you will return. In this case, a tuple of the buffered values which rerpesent AdjustedUsage and LeftOverPromoAmt
    override def evaluate(buffer: Row): Any = {
      (buffer.getLong(0), buffer.getLong(1))
    }

  }

Basically, it creates a function which can be used in Spark SQL that receives two columns ( Usage and promoAmount , as specified in the method inputSchema ), and returns a new column with two subcolums ( AdjustedUsage and LeftOverPromAmt , as defined in the method dataType ). 基本上,它创建其可以在接收两列(火花SQL使用的功能UsagepromoAmount ,如在方法中指定inputSchema ),以及具有两个subcolums返回一个新的柱( AdjustedUsageLeftOverPromAmt ,如在方法中定义dataType ) 。 The method bufferSchema enables you to create temporary value to be used to support the operations. 使用bufferSchema方法,您可以创建用于支持操作的临时值。 In this case, I have defined AdjustedUsage and LeftOverPromoAmt . 在这种情况下,我定义了AdjustedUsageLeftOverPromoAmt

The logic you are applying is implemented in the method update . 您要应用逻辑在方法update Basically, that takes the values previously calculated and updates them. 基本上,它将采用先前计算的值并进行更新。 The argument buffer contains the temporary values defined in bufferSchema , and input keeps the value of the row that is being processed in that moment. 参数buffer包含在bufferSchema定义的临时值, input保留该时刻正在处理的行的值。 Finally, evaluate returns a tuple object containing the result of the operations for each row, in this case, the temporary values defined in bufferSchema and updated in the method update . 最后, evaluate返回一个元组对象,其中包含每一行的操作结果,在这种情况下,是在bufferSchema定义并在方法update的临时值。

The next step is to create a variable by instantiating the class CalculatePromos . 下一步是通过实例化CalculatePromos类来创建变量。

val calculatePromos = new CalculatePromos

Finally, you have to apply the User Defined Aggregate Function calculatePromos by using the method withColumn of the dataset. 最后,您必须使用数据集的withColumn方法来应用用户定义的聚合函数calculatePromos Note that you have to pass it the input columns ( Usage and promoAmount ) and apply the window by using the method over. 请注意,您必须将输入列( UsagepromoAmount )传递给它,然后通过使用方法来应用窗口。

ds
  .withColumn("output", calculatePromos($"Usage", $"promoAmount").over(window))
  .select($"Month", $"Usage", $"output.LeftOverPromoAmt".as("LeftOverPromoAmt"), $"output.AdjustedUsage".as("AdjustedUsage"))

This is the result: 结果如下:

+--------+-----+----------------+-------------+
|   Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101|  600|             400|          600|
|20180201|  900|               0|          400|
|20180301|  400|               0|            0|
|20180401|  600|            2400|          600|
|20180501| 1000|            1400|         1000|
|20180601| 1900|               0|          500|
|20180701|  500|               0|            0|
|20180801|  100|               0|            0|
|20180901|  500|               0|            0|
+--------+-----+----------------+-------------+

Hope it helps. 希望能帮助到你。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM