簡體   English   中英

Spark / Scala:從帶有數組類型列的DataFrame中刪除某些組件

[英]Spark/Scala: Remove some component from a DataFrame with Array typed column

標題可能不太清楚。 讓我用一個例子來解釋我要實現的目標。 從DataFrame開始,如下所示:

val df = Seq((1, "CS", 0, (0.1, 0.2, 0.4, 0.5)), 
             (4, "Ed", 0, (0.4, 0.8, 0.3, 0.6)),
             (7, "CS", 0, (0.2, 0.5, 0.4, 0.7)),
             (101, "CS", 1, (0.5, 0.7, 0.3, 0.8)),
             (5, "CS", 1, (0.4, 0.2, 0.6, 0.9)))
             .toDF("id", "dept", "test", "array")
+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

我想根據id,dept和test列中的信息刪除/刪除array列的某些元素。 具體來說,每個數組中的4個元素對應於CS dept中的4個id,並且數字以id升序生成(表示1、5、7、101)。 現在,我要刪除每個數組中與具有test列的id對應的ID的元素。在此示例中,將刪除第二個和第四個元素,最終結果將如下所示:

+---+----+----+----------+
| id|dept|test|     array|
+---+----+----+----------+
|  1|  CS|   0|[0.1, 0.4]|
|  4|  Ed|   0|[0.4, 0.3]|
|  7|  CS|   0|[0.2, 0.4]|
|101|  CS|   1|[0.5, 0.3]|
|  5|  CS|   1|[0.4, 0.6]|
+---+----+----+----------+

為了避免收集所有結果並在Scala中進行操作。 我想盡可能地將操作保留在Spark DataFrame中。 解決這個問題的想法包括兩個步驟:

  1. 找出需要刪除的數組元素的索引
  2. 應用刪除/刪除操作

到目前為止,我想我已經確定了步驟1如下:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    val w = Window.partitionBy("dept").orderBy("id")
    val studentIdIdx = df.select("id", "dept")
      .withColumn("Index", row_number().over(w))
      .where("dept = 'CS'").drop("dept")
    studentIdIdx.show()
    +---+-----+
    | id|Index|
    +---+-----+
    |  1|    1|
    |  5|    2|
    |  7|    3|
    |101|    4|
    +---+-----+
    val testIds = df.where("test = 1")
      .select($"id".as("test_id"))
    val testMask = studentIdIdx
      .join(testIds, studentIdIdx("id") === testIds("test_id"))
      .drop("id","test_id")
    testMask.show()
    +-----+
    |Index|
    +-----+
    |    2|
    |    4|
    +-----+

所以我的兩個相關問題是:

  1. 如何將刪除/刪除功能應用於帶有索引的每一行中的每個數組? (我也願意提出一種更好的方法來計算指數)

  2. 我想要的真正的最終DataFrame應該在上述結果之上刪除更多元素。 具體來說,對於test = 0和dept = CS,它應刪除與id的索引對應的數組元素。 在此示例中,應刪除id = 1的行中的第1個元素和id = 7的行中的第3個元素(任何刪除之前的原始索引),真正的最終結果是:

     +---+----+----+----------+ | id|dept|test| array| +---+----+----+----------+ | 1| CS| 0|[0.4] | | 4| Ed| 0|[0.4, 0.3]| | 7| CS| 0|[0.2] | |101| CS| 1|[0.5, 0.3]| | 5| CS| 1|[0.4, 0.6]| +---+----+----+----------+ 

我提到第二點是為了以防萬一,可以采用一種更有效的方法來同時實現這兩個刪除操作。 如果沒有,我想我一旦知道如何使用索引信息進行刪除操作,便應該能夠弄清楚如何進行第二次刪除。 謝謝!

這是我的解決方案。 要刪除索引,我將使用UDF:

val df = Seq(
  (1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
  (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
  (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
  (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
  (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
).toDF("id", "dept", "test", "array")


val dropElements = udf(
  (array: Seq[Double], indices: Seq[Int]) =>
    array.zipWithIndex.filterNot { case (x, i) => indices.contains(i + 1) }.map(_._1)
)

df
  .withColumn("index_to_drop", row_number().over(Window.partitionBy($"dept").orderBy($"id")))
  .withColumn("index_to_drop", when($"test" === 1, $"index_to_drop"))
  .withColumn("indices_to_drop", collect_list($"index_to_drop").over(Window.partitionBy($"dept")))
  .withColumn("array", dropElements($"array", $"indices_to_drop"))
  .select($"id", $"dept", $"test", $"array")
  .show()


+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|          [0.1, 0.4]|
|  5|  CS|   1|          [0.4, 0.6]|
|  7|  CS|   0|          [0.2, 0.4]|
|101|  CS|   1|          [0.5, 0.3]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
+---+----+----+--------------------+

這是沒有UDF的另一種解決方案。 我建議您盡量避免使用UDF。 posexplode功能可從spark 2.1.0獲得。 還沒有添加評論抱歉。

import org.apache.spark.sql.functions.posexplode
import org.apache.spark.sql.expressions.Window

val df = Seq((1, "CS", 0, Array(0.1, 0.2, 0.4, 0.5)), 
             (4, "Ed", 0, Array(0.4, 0.8, 0.3, 0.6)),
             (7, "CS", 0, Array(0.2, 0.5, 0.4, 0.7)),
             (101, "CS", 1, Array(0.5, 0.7, 0.3, 0.8)),
             (5, "CS", 1, Array(0.4, 0.2, 0.6, 0.9)))
             .toDF("id", "dept", "test", "arraytoprocess")

scala> df.show()
+---+----+----+--------------------+
| id|dept|test|      arraytoprocess|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

val columnIndicestoDrop = df.withColumn("zipRank",row_number().over(Window.partitionBy("dept")
    .orderBy("id")))
    .withColumn("pos",when($"test" === 1, $"zipRank"-1))
    .filter('pos.isNotNull)
    .select('pos)
    .distinct()

scala> columnIndicestoDrop.show()
+---+
|pos|
+---+
|  1|
|  3|
+---+

val dfwitharrayIndices = df.select('id,
    'dept,
    'test,
    'arraytoprocess,
    posexplode($"arraytoprocess") as Seq("pos", "val"))

scala> dfwitharrayIndices.show()
+---+----+----+--------------------+---+---+
| id|dept|test|      arraytoprocess|pos|val|
+---+----+----+--------------------+---+---+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  0|0.1|
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  1|0.2|
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  2|0.4|
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  3|0.5|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  0|0.4|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  1|0.8|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  2|0.3|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  3|0.6|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  0|0.2|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  1|0.5|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  2|0.4|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  3|0.7|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  0|0.5|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  1|0.7|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  2|0.3|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  3|0.8|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  0|0.4|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  1|0.2|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  2|0.6|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  3|0.9|
+---+----+----+--------------------+---+---+

val finaldataFrame = dfwitharrayIndices
    .join(broadcast(columnIndicestoDrop),Seq("pos"),"leftanti")
    .select('id,'dept,'test,'val)
    .groupBy('id,'dept,'test)
    .agg(collect_list('val).as("finalarray"))

scala> finaldataFrame.show()
+---+----+----+----------+
| id|dept|test|finalarray|
+---+----+----+----------+
|  5|  CS|   1|[0.4, 0.6]|
|  4|  Ed|   0|[0.4, 0.3]|
|  1|  CS|   0|[0.1, 0.4]|
|  7|  CS|   0|[0.2, 0.4]|
|101|  CS|   1|[0.5, 0.3]|
+---+----+----+----------+

假設您有初始數據框

+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

root
 |-- id: integer (nullable = false)
 |-- dept: string (nullable = true)
 |-- test: integer (nullable = false)
 |-- array: array (nullable = true)
 |    |-- element: double (containsNull = false)

您可以應用窗口函數來獲取行號為

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("dept").orderBy("id")

val tempdf = df.withColumn("Index", row_number().over(w))

這會給你

+---+----+----+--------------------+-----+
|id |dept|test|array               |Index|
+---+----+----+--------------------+-----+
|1  |CS  |0   |[0.1, 0.2, 0.4, 0.5]|1    |
|5  |CS  |1   |[0.4, 0.2, 0.6, 0.9]|2    |
|7  |CS  |0   |[0.2, 0.5, 0.4, 0.7]|3    |
|101|CS  |1   |[0.5, 0.7, 0.3, 0.8]|4    |
|4  |Ed  |0   |[0.4, 0.8, 0.3, 0.6]|1    |
+---+----+----+--------------------+-----+

下一步將選擇具有dept = CS和test = 1的行並獲取索引列表

val csStudentIdIdxToRemove = tempdf.filter("dept = 'CS' and test = '1'").select(collect_list(tempdf("Index"))).collect()(0).getAs[Seq[Int]](0)
//WrappedArray(2, 4)

然后,您定義一個udf函數,使用所有邏輯將數組列中的元素刪除為

def removeUdf = udf((array: Seq[Double], additionalIndex: Int) => additionalIndex match{
  case 0 => array.zipWithIndex.filterNot(x => csStudentIdIdxToRemove.contains(x._2 + 1)).map(_._1)
  case _ => {
    val withAdditionalIndex = csStudentIdIdxToRemove ++ Seq(additionalIndex)
    array.zipWithIndex.filterNot(x => withAdditionalIndex.contains(x._2 + 1)).map(_._1)
  }
})

然后調用udf函數並刪除Index列

tempdf.withColumn("array", removeUdf(tempdf("array"), when(tempdf("dept") === "CS" && tempdf("test") === 0, tempdf("Index")).otherwise(lit(0))))
    .drop("Index")

最后你應該有你想要的結果

+---+----+----+----------+
|1  |CS  |0   |[0.4]     |
|5  |CS  |1   |[0.4, 0.6]|
|7  |CS  |0   |[0.2]     |
|101|CS  |1   |[0.5, 0.3]|
|4  |Ed  |0   |[0.4, 0.3]|
+---+----+----+----------+

我希望答案簡潔明了,有幫助

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM