How to implement the Seq.grouped(size:Int): Seq[Seq[A]] for Dataset in Spark

I want to implement the def grouped(size: Int): Iterator[Repr] method that Seq has, but for a Dataset in Spark.

So the input should be ds: Dataset[A] and size: Int, and the output a Seq[Dataset[A]] in which no Dataset[A] has more than size rows.
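For reference, here is the standard-library behavior the Dataset version should reproduce, shown on an ordinary Scala collection. GroupedSemantics and chunks are illustrative names only. Seq.grouped yields consecutive chunks of exactly size elements, except possibly the last chunk, which holds whatever is left over.

```scala
object GroupedSemantics {
  // Consecutive chunks of at most `size` elements, in the original order.
  // toList forces the Iterator returned by grouped into a concrete Seq.
  def chunks[A](xs: Seq[A], size: Int): Seq[Seq[A]] = xs.grouped(size).toList
}
```

With 7 elements and size = 3, this gives chunks of 3, 3 and 1 elements.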

How should I proceed? I tried repartition and mapPartitions, but I am not sure where to go from there.

Thank you.

Edit: I found the glom method on RDD, but it produces an RDD[Array[A]]. How do I go the other way, from that to an Array[RDD[A]]?
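One way to turn a single distributed dataset into a collection of smaller ones, without glom, is to number the rows with RDD.zipWithIndex and filter on index / size. The following is a sketch, not a tuned implementation. The GroupedSketch object and its grouped method are hypothetical names. The sketch assumes that an implicit Encoder and ClassTag for A are in scope, for example via spark.implicits._. Each returned Dataset is a filter over the same cached, indexed RDD, so materializing every batch scans that cached data once per batch.

```scala
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import scala.reflect.ClassTag

object GroupedSketch {
  // Hypothetical helper: splits `ds` into consecutive Datasets of at most `size` rows.
  def grouped[A: Encoder: ClassTag](ds: Dataset[A], size: Int): Seq[Dataset[A]] = {
    require(size > 0, "size must be positive")
    val spark: SparkSession = ds.sparkSession

    // zipWithIndex numbers rows 0, 1, 2, ... by partition index, then by position
    // within each partition. Caching avoids recomputing the indexing for every batch.
    val indexed = ds.rdd.zipWithIndex().cache()
    val count = indexed.count()

    // Integer ceiling division: the number of batches needed for `count` rows.
    val batches = ((count + size - 1) / size).toInt

    (0 until batches).map { b =>
      // Rows with index in [b * size, (b + 1) * size) belong to batch b.
      val rows = indexed.filter { case (_, i) => i / size == b.toLong }.map(_._1)
      spark.createDataset(rows)
    }.toList
  }
}
```

Batches follow the partition order of the input. To control which rows land together, sort the Dataset before calling the helper.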

Here you go, something that does what you want:

/*
Sample input (/data.json, one JSON object per line):
{"countries":"pp1"}
{"countries":"pp2"}
{"countries":"pp3"}
{"countries":"pp4"}
{"countries":"pp5"}
{"countries":"pp6"}
{"countries":"pp7"}
*/

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object SparkApp {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName("Simple Application")
      .master("local")
      .config("spark.ui.enabled", "false")
      .getOrCreate()

    val dataFrame: DataFrame = spark.read.json("/data.json")

    val k = 3 // maximum number of rows per batch

    // Every row gets the same "grouped" value, so the window numbers all rows in one
    // sequence. This moves all data into a single partition, so it suits small data only.
    val windowSpec = Window.partitionBy("grouped").orderBy("countries")

    val newDF = dataFrame.withColumn("grouped", lit("grouping"))

    // row_number() starts at 1: rows 1..k form the first batch, k+1..2k the second, ...
    val latestDF = newDF.withColumn("row", row_number().over(windowSpec))

    val totalCount = latestDF.count()

    // One DataFrame per batch of at most k rows: the rows in (lowLimit, lowLimit + k].
    val batches: Seq[DataFrame] = (0L until totalCount by k.toLong).map { lowLimit =>
      latestDF.where(col("row") > lowLimit && col("row") <= lowLimit + k)
    }

    batches.foreach(_.show(false))
    spark.stop()
  }
}
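You can check the batch boundaries without a cluster by reproducing the same arithmetic on plain numbers. The helper below is hypothetical. It lists the (low, high] row-number bounds that the batching produces for totalCount rows and batch size k, with row_number() starting at 1. The output shows that 7 rows with k = 3 yield three batches, the last of which contains a single row.

```scala
object WindowBounds {
  // Batch i covers row numbers in (low, low + k]; low steps 0, k, 2k, ... below totalCount.
  def bounds(totalCount: Long, k: Int): Seq[(Long, Long)] =
    (0L until totalCount by k.toLong).map(low => (low, low + k)).toList
}
```

The last upper bound can exceed totalCount (9 for 7 rows). This is harmless because the filter simply matches fewer rows.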

Here is the solution I found, but I am not sure whether it works reliably:

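  override protected def batch[A](
    input:     Dataset[A],
    batchSize: Int
  ): Seq[Dataset[A]] = {
    val count = input.count()
    // Convert to Double before dividing: Long / Int truncates, so ceil would have no effect.
    val partitionQuantity = Math.ceil(count.toDouble / batchSize).toInt

    // randomSplit requires weights that sum to a positive number, so an empty input
    // (zero splits) has to be handled separately.
    if (partitionQuantity == 0) Seq.empty
    // Splits are only approximately equal in size; a split may exceed batchSize.
    else input.randomSplit(Array.fill(partitionQuantity)(1.0 / partitionQuantity), seed = 0).toSeq
  }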
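Computing partitionQuantity as Math.ceil(count / batchSize) is an easy trap. Because count is a Long and batchSize an Int, the division truncates before Math.ceil ever sees a fraction, so 7 rows with batchSize = 3 give 2 splits instead of 3. The hypothetical BatchCount object below contrasts that behavior with two correct ways of rounding up. Even with the correct count, randomSplit assigns rows to splits at random according to the weights. Individual splits are therefore only approximately count / partitionQuantity in size and can exceed batchSize, so randomSplit does not give the hard upper bound the question asks for.

```scala
object BatchCount {
  // What Math.ceil(count / batchSize) computes: Long / Int is integer division,
  // so the fractional part is already discarded before ceil runs.
  def truncated(count: Long, batchSize: Int): Int =
    Math.ceil((count / batchSize).toDouble).toInt

  // Divide in floating point first, then round up.
  def viaDouble(count: Long, batchSize: Int): Int =
    Math.ceil(count.toDouble / batchSize).toInt

  // Integer-only ceiling division, with no floating point involved.
  def viaLong(count: Long, batchSize: Int): Int =
    ((count + batchSize - 1) / batchSize).toInt
}
```

The integer-only form avoids any question of floating-point rounding for very large counts.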
