
Spark collect_list and limit resulting list

I have a dataframe of the following format:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

What I want to do is group the dataframe by the name, collect the list and limit the size of the list.

This is how I group by the name and collect the list:

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

The resulting dataframe is something like:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

What I want to do is limit the size of the produced lists for each key. I've tried multiple ways to do that but had no success. I've already seen some posts that suggest third-party solutions, but I want to avoid that. Is there a way?

You can create a function that limits the size of the aggregated ArrayType column as shown below:

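import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
// toDF and $"..." need the session implicits (already imported in spark-shell)
import spark.implicits._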

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

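// Builds a new array from elements 0 until n of arrCol. Lists shorter than n are
// padded with nulls (with ANSI mode off, getItem returns null for an index past the end).
def limitSize(n: Int, arrCol: Column): Column =
  array( (0 until n).map( arrCol.getItem ): _* )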

df.
  groupBy("name").agg( collect_list(col("merged")).as("final") ).
  select( $"name", limitSize(2, $"final").as("final2") ).
  show(false)
// +----+----------------------------------------------+
// |name|final2                                        |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
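On Spark 2.4 or later, the built-in `slice(column, start, length)` function can truncate the collected list directly, without padding shorter lists with nulls. A minimal sketch, assuming a local SparkSession and plain string values standing in for the structs above (the app name and sample values are illustrative). Note that `collect_list` does not guarantee element order after a shuffle, so which elements survive the cut is not deterministic.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, slice}

val spark = SparkSession.builder().master("local[*]").appName("slice-demo").getOrCreate()
import spark.implicits._

// Illustrative data: strings instead of (internalKey, value) structs
val df = Seq(
  ("key1", "value1"), ("key1", "value2"),
  ("key2", "value3"), ("key2", "value4"), ("key2", "value5")
).toDF("name", "merged")

// slice is 1-based: keep at most 2 elements starting at the first one
val limited = df
  .groupBy("name")
  .agg(collect_list(col("merged")).as("final"))
  .select(col("name"), slice(col("final"), 1, 2).as("final2"))

val sizes: Map[String, Int] =
  limited.collect().map(r => r.getString(0) -> r.getSeq[String](1).size).toMap
```

Unlike the `getItem`-based helper, `slice` simply returns the whole list when it has fewer than `length` elements.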

While a UDF does what you need, a more performant and memory-friendly approach is to write a UDAF, because the list is capped during aggregation instead of being collected in full and truncated afterwards. Unfortunately, the public UDAF API is not as extensible as the aggregate functions that ship with Spark. However, you can use Spark's internal APIs to build on its internal aggregate functions and do what you need.

Here is an implementation of collect_list_limit that is mostly a copy-paste of Spark's internal CollectList AggregateFunction. I would just extend it, but it's a case class. Really, all that's needed is to override the update and merge methods so they respect a passed-in limit:

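import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, ImperativeAggregate}
import scala.collection.mutable

case class CollectListLimit(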
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  // limitExp must be foldable (e.g. a literal): it is evaluated once, without an input row
  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}
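The update/merge rules above can be modeled outside Spark to see how the cap behaves. The sketch below is a hypothetical plain-Scala stand-in (`LimitedBuffer` is not a Spark class) that mirrors the two overridden methods, with two independently filled buffers playing the role of two partitions. When one side of a merge is already full, the other side's elements are dropped, which is acceptable since collect_list gives no ordering guarantee anyway.

```scala
import scala.collection.mutable

// Hypothetical model of CollectListLimit's buffer logic (no Spark involved)
final class LimitedBuffer(limit: Int) {
  def create(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty[Any]

  // Mirrors update(): append only while the buffer is below the limit
  def update(buffer: mutable.ArrayBuffer[Any], input: Any): mutable.ArrayBuffer[Any] =
    if (buffer.size < limit) buffer += input else buffer

  // Mirrors merge(): a full side wins as-is, otherwise concatenate and truncate
  def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] =
    if (buffer.size >= limit) buffer
    else if (other.size >= limit) other
    else (buffer ++= other).take(limit)
}

val lb = new LimitedBuffer(3)

// Two "partitions" aggregated independently
val p1 = Seq("a", "b").foldLeft(lb.create())(lb.update)
val p2 = Seq("c", "d", "e", "f").foldLeft(lb.create())(lb.update) // "f" is never added

// p2 is already full, so it is returned unchanged
val mergedFull = lb.merge(p1, p2)

// Neither side is full: concatenate, then cut to the limit
val mergedPartial = lb.merge(p1, mutable.ArrayBuffer[Any]("x", "y"))
```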

To actually register it, we can use Spark's internal FunctionRegistry, which takes the function name and a builder, effectively a function that creates a CollectListLimit from the provided argument expressions:

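import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Expression

// args(0) is the column to collect, args(1) is the limit.
// Internal API: the registerFunction signature may differ between Spark versions.
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )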

Edit:

It turns out that adding it to the built-in registry only works if you haven't created the SparkContext yet, because Spark makes an immutable clone of it on startup. If you already have a context, this should work to add it via reflection:

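import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.catalog.SessionCatalog

// functionRegistry is a private (possibly name-mangled) field, so search the declared
// fields; getFields only returns public ones
val field = classOf[SessionCatalog].getDeclaredFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )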

You can use a UDF.

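Here is a possible example that reuses the input DataFrame's schema (so you don't have to write one by hand) and performs a meaningful reduction: it keeps only the row with the highest quality in each group.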

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._

import scala.collection.mutable

object TestJob1 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    import sparkSession.implicits._

    val rawDf = Seq(
      ("key", 1L, "gargamel"),
      ("key", 4L, "pe_gadol"),
      ("key", 2L, "zaam"),
      ("key1", 5L, "naval")
    ).toDF("group", "quality", "other")

    rawDf.show(false)
    rawDf.printSchema

    // The UDF returns one of the original rows, so the input schema is its return type
    val rawSchema = rawDf.schema

    // Untyped udf(f, dataType): deprecated and disabled by default in Spark 3.x
    val fUdf = udf(reduceByQuality, rawSchema)

    val aggDf = rawDf
      .groupBy("group")
      .agg(
        count(struct("*")).as("num_reads"),
        max(col("quality")).as("quality"),
        collect_list(struct("*")).as("horizontal")
      )
      .withColumn("short", fUdf($"horizontal"))
      .drop("horizontal")

    aggDf.printSchema

    aggDf.show(false)
  }

  // Reduces the collected rows of a group to the single row with the highest quality
  def reduceByQuality = (x: Any) => {
    val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

    d.reduce((r1, r2) => {
      val quality1 = r1.getAs[Long]("quality")
      val quality2 = r2.getAs[Long]("quality")
      // On ties, keep the earlier row
      if (quality1 >= quality2) r1 else r2
    })
  }
}
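If the goal is only to keep the highest-quality row of each group, a UDF may not be needed: Spark compares structs field by field, so `max` over a struct whose first field is `quality` selects that row. A sketch with the same sample data, assuming a local SparkSession (the app name is illustrative). Unlike the UDF above, ties on `quality` are broken by the larger `other` value rather than by list position.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, max, struct}

val spark = SparkSession.builder().master("local[*]").appName("max-struct-demo").getOrCreate()
import spark.implicits._

val rawDf = Seq(
  ("key", 1L, "gargamel"),
  ("key", 4L, "pe_gadol"),
  ("key", 2L, "zaam"),
  ("key1", 5L, "naval")
).toDF("group", "quality", "other")

// quality comes first in the struct, so max() orders primarily by quality
val best = rawDf
  .groupBy("group")
  .agg(max(struct(col("quality"), col("other"))).as("best"))
  .select(col("group"), col("best.quality").as("quality"), col("best.other").as("other"))

val bestByGroup: Map[String, (Long, String)] =
  best.collect().map(r => r.getString(0) -> ((r.getLong(1), r.getString(2)))).toMap
```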

Here is an example with data like yours:

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

import scala.collection.mutable

object TestJob {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    import sparkSession.implicits._

    // The inner tuples become a struct column with fields _1 and _2
    val df1 = Seq(
      ("key1", ("internalKey1", "value1")),
      ("key1", ("internalKey2", "value2")),
      ("key2", ("internalKey3", "value3")),
      ("key2", ("internalKey4", "value4")),
      ("key2", ("internalKey5", "value5"))
    ).toDF("name", "merged")

    val res = df1
      .groupBy("name")
      .agg(collect_list(col("merged")).as("final"))

    res.printSchema
    res.show(false)

    // Keeps only the first collected element of each list, rendered as a string
    def f = (x: Any) => {
      val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
      d.head.toString
    }

    // Untyped udf(f, dataType): deprecated and disabled by default in Spark 3.x
    val fUdf = udf(f, StringType)

    val d2 = res
      .withColumn("d", fUdf(col("final")))
      .drop("final")

    d2.printSchema()
    d2.show(false)
  }
}
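To actually limit each list (rather than keep only its first element) for data shaped like the question's, a typed UDF can take the first `n` elements. A sketch, assuming a local SparkSession and a hypothetical limit of 2; struct elements reach the UDF as `Row`s and are rebuilt as tuples, so the output fields are named `_1` and `_2`. As with any collect_list result, which elements come first is not guaranteed.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, udf}

val spark = SparkSession.builder().master("local[*]").appName("udf-take-demo").getOrCreate()
import spark.implicits._

val df1 = Seq(
  ("key1", ("internalKey1", "value1")),
  ("key1", ("internalKey2", "value2")),
  ("key2", ("internalKey3", "value3")),
  ("key2", ("internalKey4", "value4")),
  ("key2", ("internalKey5", "value5"))
).toDF("name", "merged")

// Keep at most 2 elements (hypothetical limit) and rebuild each struct as a tuple
val takeFirst = udf((xs: Seq[Row]) => xs.take(2).map(r => (r.getString(0), r.getString(1))))

val limited = df1
  .groupBy("name")
  .agg(collect_list(col("merged")).as("final"))
  .withColumn("final", takeFirst(col("final")))

val sizes: Map[String, Int] =
  limited.collect().map(r => r.getString(0) -> r.getSeq[Row](1).size).toMap
```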

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact yoyou2525@163.com.

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM