
Discarding first few values while calculating moving average using Spark window function

I'm trying to calculate a quarterly moving average on a column, grouped by name, and I've defined a Spark window spec as:

val wSpec1 = Window.partitionBy("name").orderBy("date").rowsBetween(-2, 0)
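With rowsBetween(-2, 0), each frame covers the current row and the two rows preceding it within the partition, so the first two rows of each group are averaged over fewer than three values.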

My DataFrame looks like this:


+-----+----------+-----------+------------------+
| name|      date|amountSpent|         movingAvg|
+-----+----------+-----------+------------------+
|  Bob|2016-01-01|       25.0|              25.0|
|  Bob|2016-02-02|       25.0|              25.0|
|  Bob|2016-03-03|       25.0|              25.0|
|  Bob|2016-04-04|       29.0|26.333333333333332|
|  Bob|2016-05-06|       27.0|              27.0|
|Alice|2016-01-01|       50.0|              50.0|
|Alice|2016-02-03|       45.0|              47.5|
|Alice|2016-03-04|       55.0|              50.0|
|Alice|2016-04-05|       60.0|53.333333333333336|
|Alice|2016-05-06|       65.0|              60.0|
+-----+----------+-----------+------------------+

The first accurately calculated value in each name group is the third one, where the window first spans three rows. I want to replace the first two values in each group with some string, say NULL. With my limited knowledge of Spark/Scala, I've thought about extracting this column from the DataFrame and using Scala's patch function. However, I cannot figure out how to replace the values at intervals, such as at the start of the second name group. Here's my code:

import com.datastax.spark.connector._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
object Test {

  def main(args: Array[String]) {
    val sparkSession = SparkSession.builder
      .master("local")
      .appName("Test")
      .config("spark.cassandra.connection.host", "localhost")
      .config("spark.driver.host", "localhost")
      .getOrCreate()
    val sc = sparkSession.sparkContext
    import sparkSession.implicits._

    val customers = sc.parallelize(List(("Alice", "2016-01-01", 50.00),
      ("Alice", "2016-02-03", 45.00),
      ("Alice", "2016-03-04", 55.00),
      ("Alice", "2016-04-05", 60.00),
      ("Alice", "2016-05-06", 65.00),
      ("Bob", "2016-01-01", 25.00),
      ("Bob", "2016-02-02", 25.00),
      ("Bob", "2016-03-03", 25.00),
      ("Bob", "2016-04-04", 29.00),
      ("Bob", "2016-05-06", 27.00))).toDF("name", "date", "amountSpent")

    // Create a window spec.
    val wSpec1 = Window.partitionBy("name").orderBy("date").rowsBetween(-2, 0)

    val ls = customers.withColumn("movingAvg", avg($"amountSpent").over(wSpec1))
    ls.show()

  }
}

I would suggest calculating the average only if the window contains exactly 3 rows (i.e., it spans the entire range from -2 to 0):

val ls = customers
  .withColumn("count", count($"amountSpent").over(wSpec1))
  .withColumn("movingAvg", when($"count" === 3, avg($"amountSpent").over(wSpec1)))

ls.show()
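Since the when expression has no otherwise branch, rows whose window holds fewer than three values evaluate to null: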


+-----+----------+-----------+-----+------------------+
| name|      date|amountSpent|count|         movingAvg|
+-----+----------+-----------+-----+------------------+
|  Bob|2016-01-01|       25.0|    1|              null|
|  Bob|2016-02-02|       25.0|    2|              null|
|  Bob|2016-03-03|       25.0|    3|              25.0|
|  Bob|2016-04-04|       29.0|    3|26.333333333333332|
|  Bob|2016-05-06|       27.0|    3|              27.0|
|Alice|2016-01-01|       50.0|    1|              null|
|Alice|2016-02-03|       45.0|    2|              null|
|Alice|2016-03-04|       55.0|    3|              50.0|
|Alice|2016-04-05|       60.0|    3|53.333333333333336|
|Alice|2016-05-06|       65.0|    3|              60.0|
+-----+----------+-----------+-----+------------------+
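If you'd rather not carry the helper count column, an equivalent approach is to rank the rows within each group with row_number and null out the first two. This is only a minimal sketch of that alternative, reusing customers and wSpec1 from above; the wRank window name is assumed, not from the original code:

// A separate window with an ordering but no frame, for row_number.
val wRank = Window.partitionBy("name").orderBy("date")

val ls2 = customers
  .withColumn("rn", row_number().over(wRank))
  // Rows 1 and 2 of each group fail the condition and become null.
  .withColumn("movingAvg", when($"rn" >= 3, avg($"amountSpent").over(wSpec1)))
  .drop("rn")

ls2.show()

Both versions produce the same movingAvg column; the count-based version additionally tolerates gaps if the frame is later changed to a range-based window.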
