
Spark collect_list and limit resulting list

I have a DataFrame in the following format:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

What I want to do is group the DataFrame by name, collect the values into a list, and limit the size of that list.

This is how I group by name and collect the list:

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

The resulting DataFrame looks something like this:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

I want to limit the size of the list produced for each key. I have tried several approaches without success. I have already seen posts that suggest third-party solutions, but I want to avoid those. Is there a way?

You can create a function that limits the size of the aggregated ArrayType column, as shown below:

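import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
// Needed for toDF and the $"..." syntax; assumes a SparkSession named `spark`, as in spark-shell.
import spark.implicits._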

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

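// Builds a new array from positions 0 until n of arrCol.
// Positions past the end of a shorter source array typically come back as null.
def limitSize(n: Int, arrCol: Column): Column =
  array( (0 until n).map( arrCol.getItem ): _* )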

df.
  groupBy("name").agg( collect_list(col("merged")).as("final") ).
  select( $"name", limitSize(2, $"final").as("final2") ).
  show(false)
// +----+----------------------------------------------+
// |name|final2                                        |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
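On Spark 2.4 or later, the built-in slice function may be a simpler way to get the same effect, and it does not pad short lists. The sketch below is an assumption-laden illustration rather than part of the original answer: the session setup, the column names k and v, and the limit of 2 are all chosen for the example. Note that the full list is still collected before it is sliced, so this limits the output size, not the memory used during aggregation.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, slice, struct}

// Local session for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("slice-sketch").getOrCreate()
import spark.implicits._

// Same shape as the question's data: a name plus a (k, v) struct called "merged".
val df = Seq(
  ("key1", "internalKey1", "value1"),
  ("key1", "internalKey2", "value2"),
  ("key2", "internalKey3", "value3"),
  ("key2", "internalKey4", "value4"),
  ("key2", "internalKey5", "value5")
).toDF("name", "k", "v")
  .select(col("name"), struct(col("k"), col("v")).as("merged"))

// slice uses a 1-based start index; here it keeps at most 2 elements per list.
val limited = df
  .groupBy("name")
  .agg(slice(collect_list(col("merged")), 1, 2).as("final"))

limited.show(false)
```

Which two elements survive for key2 depends on the order in which collect_list happened to see the rows, so the result is not deterministic after a shuffle.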

A UDF does what you need, but if you are looking for something more performant and more memory-conscious, the way to do it is to write a UDAF. Unfortunately, the UDAF API is not as extensible as the aggregate functions that ship with Spark. However, you can build on Spark's internal APIs and internal functions to do what you need.

Here is an implementation of collect_list_limit that is mostly a copy-paste of Spark's internal CollectList AggregateFunction. I would simply extend it, but it is a case class. All that is really needed is to override the update and merge methods so they respect a passed-in limit:

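import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, ImperativeAggregate}

import scala.collection.mutable

// Relies on Spark's internal (non-public) catalyst classes, which can change between releases.
case class CollectListLimit(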
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}
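The merge rule above can be checked in isolation with plain Scala collections. The helper below, mergeLimited, is a hypothetical stand-alone copy of that logic written for illustration; it is not part of Spark or of the original answer, and it assumes (as update guarantees) that neither buffer ever grows past the limit.

```scala
import scala.collection.mutable

// Hypothetical helper mirroring CollectListLimit.merge, with the limit passed in explicitly.
def mergeLimited[A](
    buffer: mutable.ArrayBuffer[A],
    other: mutable.ArrayBuffer[A],
    limit: Int): mutable.ArrayBuffer[A] =
  if (buffer.size >= limit) buffer          // left side is already full: keep it as-is
  else if (other.size >= limit) other       // right side is already full: reuse it
  else (buffer ++= other).take(limit)       // otherwise concatenate and truncate

// Neither side is full, so the buffers are concatenated and cut at the limit.
val merged = mergeLimited(mutable.ArrayBuffer(1, 2), mutable.ArrayBuffer(3, 4), 3)
println(merged)
```

When one side is already full, the other side's elements are dropped entirely, which is acceptable here because the question only asks for some bounded subset of each group.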

To register CollectListLimit, we can go through Spark's internal FunctionRegistry, which takes a name and a builder. The builder is effectively a function that creates a CollectListLimit from the provided expressions:

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )

Edit:

It turns out that adding it to the builtin registry only works if you have not created the SparkContext yet, because Spark makes an immutable clone of the registry on startup. If you already have a context, adding it through reflection should work:

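import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.SessionCatalog

// getDeclaredFields (unlike getFields) also returns non-public fields, which is what is needed here.
val field = classOf[SessionCatalog].getDeclaredFields.find( _.getName.endsWith( "functionRegistry" ) ).get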
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
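If depending on internal APIs is not an option, one alternative that uses only public functions is to cap the number of rows per group with a window function before aggregating, so that collect_list never sees more than the limit. This is a sketch under assumptions, not the approach from the answer above: the column names k and v, the ordering by merged.k, and the limit n = 2 are all chosen for illustration. It also costs a sort within each partition, so it is not necessarily faster.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, collect_list, row_number, struct}

// Local session for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("window-limit-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  ("key1", "internalKey1", "value1"),
  ("key1", "internalKey2", "value2"),
  ("key2", "internalKey3", "value3"),
  ("key2", "internalKey4", "value4"),
  ("key2", "internalKey5", "value5")
).toDF("name", "k", "v")
  .select(col("name"), struct(col("k"), col("v")).as("merged"))

val n = 2 // maximum list size per name (assumed for the example)

// Number the rows within each name; the ordering column decides which rows are kept.
val byName = Window.partitionBy("name").orderBy(col("merged.k"))

val limited = df
  .withColumn("rn", row_number().over(byName))
  .where(col("rn") <= n)   // drop surplus rows before they reach collect_list
  .groupBy("name")
  .agg(collect_list(col("merged")).as("final"))

limited.show(false)
```

Unlike the aggregate-based approach, this makes the choice of retained elements explicit through the window ordering.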

You can use a UDF.

Here is a possible example that does not require writing out a schema by hand and that performs a meaningful reduction:

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob1 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("key", 1L, "gargamel"),
      ("key", 4L, "pe_gadol"),
      ("key", 2L, "zaam"),
      ("key1", 5L, "naval")
    ).toDF("group", "quality", "other")

    rawDf.show(false)
    rawDf.printSchema()

    // The UDF returns a single row, so its result type is the schema of the input rows.
    val rawSchema = rawDf.schema

    val fUdf = udf(reduceByQuality, rawSchema)

    val aggDf = rawDf
      .groupBy("group")
      .agg(
        count(struct("*")).as("num_reads"),
        max(col("quality")).as("quality"),
        collect_list(struct("*")).as("horizontal")
      )
      .withColumn("short", fUdf($"horizontal"))
      .drop("horizontal")

    aggDf.printSchema()
    aggDf.show(false)
  }

  // Reduces the collected rows to the one with the highest quality;
  // on a tie, the earlier row is kept.
  def reduceByQuality = (x: Any) => {
    val rows = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
    rows.reduce { (r1, r2) =>
      if (r1.getAs[Long]("quality") >= r2.getAs[Long]("quality")) r1 else r2
    }
  }
}

Here is an example with data like yours:

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    import sparkSession.sqlContext.implicits._

    val df1 = Seq(
      ("key1", ("internalKey1", "value1")),
      ("key1", ("internalKey2", "value2")),
      ("key2", ("internalKey3", "value3")),
      ("key2", ("internalKey4", "value4")),
      ("key2", ("internalKey5", "value5"))
    ).toDF("name", "merged")

    val res = df1
      .groupBy("name")
      .agg(collect_list(col("merged")).as("final"))

    res.printSchema()
    res.show(false)

    // Keeps only the first collected element and renders it as a string.
    def f = (x: Any) => {
      val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
      d.head.toString
    }

    val fUdf = udf(f, StringType)

    val d2 = res
      .withColumn("d", fUdf(col("final")))
      .drop("final")

    d2.printSchema()
    d2.show(false)
  }
}
