Spark: aggregate values by a list of given years
I'm new to Scala. Say I have a dataset:
ds.show()
+------+-----------------+-------------+
| year | nb_product_sold | system_year |
+------+-----------------+-------------+
| 2010 |               1 |        2012 |
| 2012 |               2 |        2012 |
| 2012 |               4 |        2012 |
| 2015 |               3 |        2012 |
| 2019 |               4 |        2012 |
| 2021 |               5 |        2012 |
+------+-----------------+-------------+
and I have a `List<Integer> years = {1, 3, 8}`, which means the x-th year after `system_year`. The goal is to calculate the total number of products sold for each such year after `system_year`.
In other words, I have to calculate the total products sold by the years 2013, 2015 and 2020.
The output dataset should look like this:
+------+--------------------+
| year | total_product_sold |
+------+--------------------+
|    1 |                  6 |  -> 2012 - 2013: 6 products sold
|    3 |                  9 |  -> 2012 - 2015: 9 products sold
|    8 |                 13 |  -> 2012 - 2020: 13 products sold
+------+--------------------+
I want to know how to do this in Scala. Should I use `groupBy()` in this case?
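For reference, the intended computation can be sketched in plain Scala (no Spark) over the sample rows above; the `Sale` case class and field names here are just illustrative:

```scala
// Plain-Scala sketch of the required aggregation, using the sample data above.
case class Sale(year: Int, nbProductSold: Int, systemYear: Int)

val sales = List(
  Sale(2010, 1, 2012), Sale(2012, 2, 2012), Sale(2012, 4, 2012),
  Sale(2015, 3, 2012), Sale(2019, 4, 2012), Sale(2021, 5, 2012)
)
val years = List(1, 3, 8)

// For each horizon y, sum products sold between system_year and system_year + y
val totals = years.map { y =>
  y -> sales
    .filter(s => s.year >= s.systemYear && s.year <= s.systemYear + y)
    .map(_.nbProductSold)
    .sum
}
// totals: List((1,6), (3,9), (8,13))
```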
You could have used a `groupBy` with `case`/`when` if the year ranges didn't overlap. But here you'll need to do a `groupBy` for each year and then union the 3 grouped DataFrames:
import org.apache.spark.sql.functions.{lit, sum}

val years = List(1, 3, 8)

// One grouped aggregation per horizon, then union the results
val result = years.map { y =>
  df.filter($"year".between($"system_year", $"system_year" + y))
    .groupBy(lit(y).as("year"))
    .agg(sum($"nb_product_sold").as("total_product_sold"))
}.reduce(_ union _)

result.show
//+----+------------------+
//|year|total_product_sold|
//+----+------------------+
//| 1| 6|
//| 3| 9|
//| 8| 13|
//+----+------------------+
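If you'd rather avoid one grouped pass per horizon, the same totals can be produced in a single pass by duplicating each row into every horizon it falls within and then grouping once (in Spark this would be an `explode` of the horizon list followed by a single `groupBy`). Here is a plain-Scala sketch of that idea on the sample data; the names are illustrative, not from the original post:

```scala
// Single-pass sketch: emit (horizon, count) for every horizon a row belongs to,
// then group by horizon and sum — mirrors explode + one groupBy in Spark.
val sales = List(
  (2010, 1, 2012), (2012, 2, 2012), (2012, 4, 2012),
  (2015, 3, 2012), (2019, 4, 2012), (2021, 5, 2012)
)
val years = List(1, 3, 8)

val singlePass = sales
  .flatMap { case (year, sold, sysYear) =>
    years
      .filter(y => year >= sysYear && year <= sysYear + y)
      .map(y => y -> sold)
  }
  .groupBy(_._1)
  .view.mapValues(_.map(_._2).sum)
  .toMap
// singlePass: Map(1 -> 6, 3 -> 9, 8 -> 13)
```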
There might be multiple ways of doing this, some more efficient than what I am showing you, but it works for your use case.
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

// Sample data
val df = Seq(
  (2010, 1, 2012), (2012, 2, 2012), (2012, 4, 2012),
  (2015, 3, 2012), (2019, 4, 2012), (2021, 5, 2012)
).toDF("year", "nb_product_sold", "system_year")

// Difference of each year from the system year
val df1 = df.withColumn("Difference", $"year" - $"system_year")

// Running total for all years present in the DataFrame, partitioned by year
val w = Window.partitionBy("year").orderBy("year")
val df2 = df1
  .withColumn("runningsum", sum("nb_product_sold").over(w))
  .withColumn("yearlist", lit(0))
  .dropDuplicates("year", "system_year", "Difference")

// Years list
val years = List(1, 3, 8)

// Build a DataFrame with the total count for each year, union them all,
// and remove duplicates
var df3 = spark.createDataFrame(sc.emptyRDD[Row], df2.schema)
for (year <- years) {
  val innerdf = df2
    .filter($"Difference" >= year - 1 && $"Difference" <= year)
    .withColumn("yearlist", lit(year))
  df3 = df3.union(innerdf)
}

// Partition by system year again and sum across the collected years
val w1 = Window.partitionBy("system_year").orderBy("year")
val finaldf = df3
  .withColumn("total_product_sold", sum("runningsum").over(w1))
  .select("yearlist", "total_product_sold")
You can see the resulting totals with `finaldf.show`.