How to get the running average up to the current position in an RDD with Spark/Scala

I have an RDD, and for each element I want the average of all values up to and including the current position. For example:

inputRDD:
1,  2,   3,  4,   5,  6,   7,  8

output:
1,  1.5, 2,  2.5, 3,  3.5, 4,  4.5
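
To pin down the expected result, here is a minimal plain-Python sketch of the running average (no Spark involved). The function name `running_average` is only illustrative and does not come from any library.

```python
from itertools import accumulate

def running_average(values):
    """Return the mean of values[0..i] for every position i."""
    # accumulate yields the prefix sums; divide each by its 1-based position
    return [total / count
            for count, total in enumerate(accumulate(values), start=1)]

print(running_average([1, 2, 3, 4, 5, 6, 7, 8]))
# [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
```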

This is my attempt:

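import scala.collection.mutable.ArrayBuffer

val rdd = sc.parallelize(List(1, 2, 3, 4, 5, 6, 7, 8), 4)
val partition = rdd.getNumPartitions
rdd.zipWithIndex().collect().foreach(println)
// Collapse everything into one partition so a single task sees all elements in order
rdd.zipWithIndex().sortBy(x => x._2, true, 1).mapPartitions { ite =>
  // Keep the running state inside the task instead of in driver-side vars
  var sum = 0.0
  var index = 0.0
  val result = new ArrayBuffer[(Double, Long)]()
  while (ite.hasNext) {
    val iteNext = ite.next()
    sum += iteNext._1
    index += 1
    val avg: Double = sum / index
    result.append((avg, iteNext._2))
  }
  result.iterator
  // Restore the original number of partitions, then drop the index
}.sortBy(x => x._2, true, partition).map(x => x._1).collect().foreach(println)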

I have to repartition to 1 and then compute the result with an array, which is very inefficient.

Is there a cleaner solution that keeps the 4 partitions and does not buffer the values in an array?

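A simpler solution is to use Spark SQL. Here I compute the running average for each row. Note that the window is ordered by col1 itself, so this matches the positional running average only because the input is already sorted by value.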

import spark.implicits._  // needed for toDF and $"..." outside spark-shell
val df = sc.parallelize(List(1,2,3,4,5,6,7,8)).toDF("col1")

df.createOrReplaceTempView("table1")

val result = spark.sql("""SELECT col1, sum(col1) over(order by col1 asc)/row_number() over(order by col1 asc) as avg FROM table1""")

Alternatively, if you prefer the DataFrame API:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._  // sum, row_number
val result = df
 // cumulative sum and 1-based row number over the same ordering
 .withColumn("csum", sum($"col1").over(Window.orderBy($"col1")))
 .withColumn("rownum", row_number().over(Window.orderBy($"col1")))
 .withColumn("avg", $"csum"/$"rownum")
 .select("col1","avg")

Output:

result.show()

+----+---+
|col1|avg|
+----+---+
|   1|1.0|
|   2|1.5|
|   3|2.0|
|   4|2.5|
|   5|3.0|
|   6|3.5|
|   7|4.0|
|   8|4.5|
+----+---+

Sorry, I don't use Scala, but I hope you can read this PySpark version:

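from pyspark.sql import Window, functions as f

df = spark.createDataFrame([(x,) for x in range(1, 9)], ['val'])
# Frame covers every row from the start of the ordering up to the current row
w = Window.orderBy('val').rowsBetween(Window.unboundedPreceding, Window.currentRow)
df = df.withColumn('spec_avg', f.avg('val').over(w))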
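
On the original question of keeping 4 partitions: one possible approach, not taken from the answers above, is a two-pass scheme. First compute a (sum, count) pair per partition, then give each partition the totals of everything before it so that it can stream its own running average. The sketch below simulates this with plain Python lists standing in for partitions. All function names are hypothetical, and it assumes the partitions are already in positional order. In Spark, each pass would presumably map onto an RDD `mapPartitionsWithIndex` call, with the small list of totals collected to the driver in between.

```python
def partition_totals(partitions):
    """Pass 1: one (sum, count) pair per partition."""
    return [(sum(p), len(p)) for p in partitions]

def prefix_offsets(totals):
    """Sum and count of everything that comes before each partition."""
    offsets = [(0, 0)]
    for s, c in totals[:-1]:
        prev_s, prev_c = offsets[-1]
        offsets.append((prev_s + s, prev_c + c))
    return offsets

def running_average_partition(offset, partition):
    """Pass 2: stream one partition from its offset, without buffering."""
    total, count = offset
    for x in partition:
        total += x
        count += 1
        yield total / count

# Hypothetical layout: the 8 values split across 4 partitions, as in the question
partitions = [[1, 2], [3, 4], [5, 6], [7, 8]]
offsets = prefix_offsets(partition_totals(partitions))
result = [avg
          for off, p in zip(offsets, partitions)
          for avg in running_average_partition(off, p)]
print(result)  # [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
```

Only one small pair per partition has to reach the driver, so no task needs to see the whole dataset.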
