簡體   English   中英

用戶按數據框分組時無法執行用戶定義的功能

[英]Failed to execute user defined function when aggregating in a dataframe groupby user

我有一個如下數據框,我正在嘗試獲取用戶groupby名稱的max(sum)。

+-----+-----------------------------+
|name |nt_set                       |
+-----+-----------------------------+
|Bob  |[av:27.0, bcd:29.0, abc:25.0]|
|Alice|[abc:95.0, bcd:55.0]         |
|Bob  |[abc:95.0, bcd:70.0]         |
|Alice|[abc:125.0, bcd:90.0]        |
+-----+-----------------------------+

以下是我用來獲取用戶的max(sum)的udf

val maxfunc = udf((arr: Array[String]) => {
val step1 = arr.map(x => (x.split(":", -1)(0), x.split(":", -1)(1))).groupBy(_._1).mapValues(arr => arr.map(_._2.toInt).sum).maxBy(_._2)
val result = step1._1 + ":" + step1._2
result})

當我運行udf時,它會拋出以下錯誤

 val c6 = c5.withColumn("max_nt", maxfunc(col("nt_set"))).show(false)

錯誤:無法執行用戶定義的函數($ anonfun $ 1 :(數組)=>字符串)

如何以更好的執行方式實現此目標,因為我需要在更大的數據集中進行

預期的結果是

expected result:
+-----+-----------------------------+
|name |max_nt                       |
+-----+-----------------------------+
|Bob  |abc:120.0                    |
|Alice|abc:220.0                    |
+-----+-----------------------------+

從我對您正在嘗試做的事情的了解中,您的例子是錯誤的。 愛麗絲的bcd字段總計為145,而她的abc字段總計為220。因此,也應為她選擇abc。 如果我錯了,那我就誤解了你的問題。

無論如何,您不需要udf即可完成所需的工作。 讓我們生成您的數據:

val df = sc.parallelize(Seq(
    ("Bob", Array("av:27.0", "bcd:29.0", "abc:25.0")), 
    ("Alice", Array("abc:95.0", "bcd:55.0")), 
    ("Bob", Array("abc:95.0", "bcd:70.0")), 
    ("Alice", Array("abc:125.0", "bcd:90.0"))) )
        .toDF("name", "nt_set")

然后,一種方法是將nt_set分解為僅包含一個字符串/值對的nt列。

df.withColumn("nt", explode('nt_set))
  //then we split the string and the value
  .withColumn("nt_string", split('nt, ":")(0))
  .withColumn("nt_value", split('nt, ":")(1).cast("int"))
  //then we sum the values by name and "string"
  .groupBy("name", "nt_string")
  .agg(sum('nt_value) as "nt_value")
  /* then we build a struct with the value first to be able to select
     the nt field with max value while keeping the corresponding string */
  .withColumn("nt", struct('nt_value, 'nt_string))
  .groupBy("name")
  .agg(max('nt) as "nt")
  // And we rebuild the "nt" column.
  .withColumn("max_nt", concat_ws(":", $"nt.nt_string", $"nt.nt_value"))
  .drop("nt").show(false)

+-----+-------+
|name |max_nt |
+-----+-------+
|Bob  |abc:120|
|Alice|abc:220|
+-----+-------+

maxfunc的核心邏輯可以maxfunc工作,只是它應該處理post-groupBy數組列,它是一個嵌套的Seq集合:

val df = Seq(
  ("Bob", Seq("av:27.0", "bcd:29.0", "abc:25.0")),
  ("Alice", Seq("abc:95.0", "bcd:55.0")),
  ("Zack", Seq()),
  ("Bob", Seq("abc:50.0", null)),
  ("Bob", Seq("abc:95.0", "bcd:70.0")),
  ("Alice", Seq("abc:125.0", "bcd:90.0"))
).toDF("name", "nt_set")

import org.apache.spark.sql.functions._

val maxfunc = udf( (ss: Seq[Seq[String]]) => {
  val groupedSeq: Map[String, Double] = ss.flatMap(identity).
    collect{ case x if x != null => (x.split(":")(0), x.split(":")(1)) }.
    groupBy(_._1).mapValues(_.map(_._2.toDouble).sum)

  groupedSeq match {
    case x if x == Map.empty[String, Double] => ("", -999.0)
    case _ => groupedSeq.maxBy(_._2)
  }
} )

df.groupBy("name").agg(collect_list("nt_set").as("arr_nt")).
  withColumn("max_nt", maxfunc($"arr_nt")).
  select($"name", $"max_nt._1".as("max_key"), $"max_nt._2".as("max_val")).
  show
// +-----+-------+-------+
// | name|max_key|max_val|
// +-----+-------+-------+
// | Zack|       | -999.0|
// |  Bob|    abc|  170.0|
// |Alice|    abc|  220.0|
// +-----+-------+-------+

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM