I am trying to adjust one of the column values based on a value in another data frame. While doing this, if there is a leftover amount, I need to carry it forward to the next row and calculate the final amount.
During this operation, I am not able to carry the previous row's leftover amount over to the next row's operation. I tried using the lag window function and running totals, but those did not work as expected.
I am working with Scala. Here is the input data:
val consumption = sc.parallelize(Seq((20180101, 600), (20180201, 900),(20180301, 400),(20180401, 600),(20180501, 1000),(20180601, 1900),(20180701, 500),(20180801, 100),(20180901, 500))).toDF("Month","Usage")
consumption.show()
+--------+-----+
| Month|Usage|
+--------+-----+
|20180101| 600|
|20180201| 900|
|20180301| 400|
|20180401| 600|
|20180501| 1000|
|20180601| 1900|
|20180701| 500|
|20180801| 100|
|20180901| 500|
+--------+-----+
val promo = sc.parallelize(Seq((20180101, 1000),(20180201, 100),(20180401, 3000))).toDF("PromoEffectiveMonth","promoAmount")
promo.show()
+-------------------+-----------+
|PromoEffectiveMonth|promoAmount|
+-------------------+-----------+
| 20180101| 1000|
| 20180201| 100|
| 20180401| 3000|
+-------------------+-----------+
expected result:
val finaldf = sc.parallelize(Seq((20180101,600,400,600),(20180201,900,0,400),(20180301,400,0,0),(20180401,600,2400,600),(20180501,1000,1400,1000),(20180601,1900,0,500),(20180701,500,0,0),(20180801,100,0,0),(20180901,500,0,0))).toDF("Month","Usage","LeftOverPromoAmt","AdjustedUsage")
finaldf.show()
+--------+-----+----------------+-------------+
| Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101| 600| 400| 600|
|20180201| 900| 0| 400|
|20180301| 400| 0| 0|
|20180401| 600| 2400| 600|
|20180501| 1000| 1400| 1000|
|20180601| 1900| 0| 500|
|20180701| 500| 0| 0|
|20180801| 100| 0| 0|
|20180901| 500| 0| 0|
+--------+-----+----------------+-------------+
The logic I am applying: based on a join between Month and PromoEffectiveMonth, the promo amount must be applied against the consumption Usage column until the promo amount becomes zero, carrying any remainder forward to the next month.
E.g.: in Jan '18 the promo amount is 1000; after deducting the usage (600), the leftover promo amount is 400 and the adjusted usage is 600. The leftover 400 is carried into February, where it is added to that month's promo amount (100), giving 500 available. Here the usage (900) is greater than the available promo.
So the leftover promo amount is zero and the adjusted usage is 400 (900 - 500).
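To make the intended rule concrete, here is a minimal pure-Scala sketch of the carry-forward logic on plain (usage, promo) pairs, outside Spark; the names and the use of foldLeft are illustrative, not part of the question's code:

```scala
// Each pair is (usage, promoForThatMonth); promo 0 means no promotion that month.
val months = Seq((600, 1000), (900, 100), (400, 0), (600, 3000), (1000, 0),
                 (1900, 0), (500, 0), (100, 0), (500, 0))

// Thread the leftover promo through the months with foldLeft.
// Each output triple is (Usage, LeftOverPromoAmt, AdjustedUsage).
val (_, result) = months.foldLeft((0, Vector.empty[(Int, Int, Int)])) {
  case ((leftover, acc), (usage, promo)) =>
    val available = leftover + promo
    val (adjusted, newLeftover) =
      if (usage < available) (usage, available - usage) // promo covers the usage
      else if (available == 0) (0, 0)                   // nothing to apply
      else (usage - available, 0)                       // promo partially covers
    (newLeftover, acc :+ (usage, newLeftover, adjusted))
}

result.foreach(println)
// First rows: (600,400,600), (900,0,400), (400,0,0), (600,2400,600), ...
```

This reproduces the expected table above; the Spark question is how to express the same sequential dependency over DataFrame rows.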
First of all, you need to perform a left_outer join so that each row has its corresponding promotion. The join is performed on the fields Month and PromoEffectiveMonth from the consumption and promo datasets, respectively. Note also that I have created a new column, Timestamp, by using the Spark SQL unix_timestamp function; it will be used to sort the dataset by date.
import org.apache.spark.sql.functions.unix_timestamp
import org.apache.spark.sql.types.TimestampType

val ds = consumption
  .join(promo, consumption.col("Month") === promo.col("PromoEffectiveMonth"), "left_outer")
  .select("Month", "Usage", "promoAmount")
  .withColumn("Timestamp", unix_timestamp($"Month".cast("string"), "yyyyMMdd").cast(TimestampType))
This is the result of these operations.
+--------+-----+-----------+-------------------+
| Month|Usage|promoAmount| Timestamp|
+--------+-----+-----------+-------------------+
|20180301| 400| null|2018-03-01 00:00:00|
|20180701| 500| null|2018-07-01 00:00:00|
|20180901| 500| null|2018-09-01 00:00:00|
|20180101| 600| 1000|2018-01-01 00:00:00|
|20180801| 100| null|2018-08-01 00:00:00|
|20180501| 1000| null|2018-05-01 00:00:00|
|20180201| 900| 100|2018-02-01 00:00:00|
|20180601| 1900| null|2018-06-01 00:00:00|
|20180401| 600| 3000|2018-04-01 00:00:00|
+--------+-----+-----------+-------------------+
Next, you have to create a Window. Window functions are used to perform calculations over a group of records according to some criteria (more info on this here). In our case, the criterion is to sort the rows by Timestamp. Note that ordering without a partitionBy moves all the data into a single partition, so Spark will log a performance warning; that is acceptable here because the logic is inherently sequential.

import org.apache.spark.sql.expressions.Window

val window = Window.orderBy("Timestamp")
Okay, now comes the hardest part. You need to create a User Defined Aggregate Function . In this function, each group will be processed according to a custom operation, and it will enable you to process each row by taking into account the value of the previous one.
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class CalculatePromos extends UserDefinedAggregateFunction {

  // Input schema for this UserDefinedAggregateFunction
  override def inputSchema: StructType =
    StructType(
      StructField("Usage", LongType) ::
      StructField("promoAmount", LongType) :: Nil)

  // Schema for the values that will be buffered internally between rows
  override def bufferSchema: StructType = StructType(
    StructField("AdjustedUsage", LongType) ::
    StructField("LeftOverPromoAmt", LongType) :: Nil
  )

  // The data type returned by this UserDefinedAggregateFunction.
  // In this case, a StructType with two fields: AdjustedUsage and LeftOverPromoAmt
  override def dataType: DataType = StructType(Seq(
    StructField("AdjustedUsage", LongType),
    StructField("LeftOverPromoAmt", LongType)))

  // Whether this UDAF is deterministic or not. In this case, it is
  override def deterministic: Boolean = true

  // Initial values for the buffered fields declared above
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // In this function, the values associated with the buffer schema are updated
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val promoAmount = if (input.isNullAt(1)) 0L else input.getLong(1)
    val leftOverAmount = buffer.getLong(1)
    val usage = input.getLong(0)
    val currentPromo = leftOverAmount + promoAmount
    if (usage < currentPromo) {
      buffer(0) = usage
      buffer(1) = currentPromo - usage
    } else {
      if (currentPromo == 0)
        buffer(0) = 0L
      else
        buffer(0) = usage - currentPromo
      buffer(1) = 0L
    }
  }

  // Function used to merge two buffers. In this case, it is not necessary to
  // implement this method since the whole logic lives in update
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}

  // What the function returns: a tuple of the buffered values, which
  // represent AdjustedUsage and LeftOverPromoAmt
  override def evaluate(buffer: Row): Any = {
    (buffer.getLong(0), buffer.getLong(1))
  }
}
Basically, this creates a function usable in Spark SQL that receives two columns (Usage and promoAmount, as specified in the method inputSchema) and returns a new column with two subcolumns (AdjustedUsage and LeftOverPromoAmt, as defined in the method dataType). The method bufferSchema lets you declare temporary values used to support the operation; in this case, AdjustedUsage and LeftOverPromoAmt.
The logic you are applying is implemented in the method update. Basically, it takes the values previously calculated and updates them: the argument buffer contains the temporary values defined in bufferSchema, and input holds the row being processed at that moment. Finally, evaluate returns a tuple containing the result of the operations for each row, that is, the temporary values defined in bufferSchema and updated in the method update.
The next step is to create a variable by instantiating the class CalculatePromos.
val calculatePromos = new CalculatePromos
Finally, you have to apply the User Defined Aggregate Function calculatePromos by using the dataset's withColumn method. Note that you have to pass it the input columns (Usage and promoAmount) and apply the window by using the method over.
ds
.withColumn("output", calculatePromos($"Usage", $"promoAmount").over(window))
.select($"Month", $"Usage", $"output.LeftOverPromoAmt".as("LeftOverPromoAmt"), $"output.AdjustedUsage".as("AdjustedUsage"))
This is the result:
+--------+-----+----------------+-------------+
| Month|Usage|LeftOverPromoAmt|AdjustedUsage|
+--------+-----+----------------+-------------+
|20180101| 600| 400| 600|
|20180201| 900| 0| 400|
|20180301| 400| 0| 0|
|20180401| 600| 2400| 600|
|20180501| 1000| 1400| 1000|
|20180601| 1900| 0| 500|
|20180701| 500| 0| 0|
|20180801| 100| 0| 0|
|20180901| 500| 0| 0|
+--------+-----+----------------+-------------+
Hope it helps.