[英]Spark/Scala: Remove some component from a DataFrame with Array typed column
標題可能不太清楚。 讓我用一個例子來解釋我要實現的目標。 從DataFrame開始,如下所示:
val df = Seq((1, "CS", 0, (0.1, 0.2, 0.4, 0.5)),
(4, "Ed", 0, (0.4, 0.8, 0.3, 0.6)),
(7, "CS", 0, (0.2, 0.5, 0.4, 0.7)),
(101, "CS", 1, (0.5, 0.7, 0.3, 0.8)),
(5, "CS", 1, (0.4, 0.2, 0.6, 0.9)))
.toDF("id", "dept", "test", "array")
+---+----+----+--------------------+
| id|dept|test| array|
+---+----+----+--------------------+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+
我想根據id,dept和test列中的信息刪除/刪除array列的某些元素。 具體來說,每個數組中的4個元素對應於CS dept中的4個id,並且數字以id升序生成(表示1、5、7、101)。 現在,我要刪除每個數組中與具有test列的id對應的ID的元素。在此示例中,將刪除第二個和第四個元素,最終結果將如下所示:
+---+----+----+----------+
| id|dept|test| array|
+---+----+----+----------+
| 1| CS| 0|[0.1, 0.4]|
| 4| Ed| 0|[0.4, 0.3]|
| 7| CS| 0|[0.2, 0.4]|
|101| CS| 1|[0.5, 0.3]|
| 5| CS| 1|[0.4, 0.6]|
+---+----+----+----------+
為了避免收集所有結果並在Scala中進行操作。 我想盡可能地將操作保留在Spark DataFrame中。 解決這個問題的想法包括兩個步驟:
到目前為止,我想我已經確定了步驟1如下:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val w = Window.partitionBy("dept").orderBy("id")
val studentIdIdx = df.select("id", "dept")
.withColumn("Index", row_number().over(w))
.where("dept = 'CS'").drop("dept")
studentIdIdx.show()
+---+-----+
| id|Index|
+---+-----+
| 1| 1|
| 5| 2|
| 7| 3|
|101| 4|
+---+-----+
val testIds = df.where("test = 1")
.select($"id".as("test_id"))
val testMask = studentIdIdx
.join(testIds, studentIdIdx("id") === testIds("test_id"))
.drop("id","test_id")
testMask.show()
+-----+
|Index|
+-----+
| 2|
| 4|
+-----+
所以我的兩個相關問題是:
如何將刪除/刪除功能應用於帶有索引的每一行中的每個數組? (我也願意提出一種更好的方法來計算指數)
我想要的真正的最終DataFrame應該在上述結果之上刪除更多元素。 具體來說,對於test = 0和dept = CS,它應刪除與id的索引對應的數組元素。 在此示例中,應刪除id = 1的行中的第1個元素和id = 7的行中的第3個元素(任何刪除之前的原始索引),真正的最終結果是:
+---+----+----+----------+ | id|dept|test| array| +---+----+----+----------+ | 1| CS| 0|[0.4] | | 4| Ed| 0|[0.4, 0.3]| | 7| CS| 0|[0.2] | |101| CS| 1|[0.5, 0.3]| | 5| CS| 1|[0.4, 0.6]| +---+----+----+----------+
我提到第二點是為了以防萬一,可以采用一種更有效的方法來同時實現這兩個刪除操作。 如果沒有,我想我一旦知道如何使用索引信息進行刪除操作,便應該能夠弄清楚如何進行第二次刪除。 謝謝!
這是我的解決方案。 要刪除索引,我將使用UDF:
val df = Seq(
(1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
(4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
(7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
(101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
(5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
).toDF("id", "dept", "test", "array")
val dropElements = udf(
(array: Seq[Double], indices: Seq[Int]) =>
array.zipWithIndex.filterNot { case (x, i) => indices.contains(i + 1) }.map(_._1)
)
df
.withColumn("index_to_drop", row_number().over(Window.partitionBy($"dept").orderBy($"id")))
.withColumn("index_to_drop", when($"test" === 1, $"index_to_drop"))
.withColumn("indices_to_drop", collect_list($"index_to_drop").over(Window.partitionBy($"dept")))
.withColumn("array", dropElements($"array", $"indices_to_drop"))
.select($"id", $"dept", $"test", $"array")
.show()
+---+----+----+--------------------+
| id|dept|test| array|
+---+----+----+--------------------+
| 1| CS| 0| [0.1, 0.4]|
| 5| CS| 1| [0.4, 0.6]|
| 7| CS| 0| [0.2, 0.4]|
|101| CS| 1| [0.5, 0.3]|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]|
+---+----+----+--------------------+
這是沒有UDF的另一種解決方案。 我建議您盡量避免使用UDF。 posexplode功能可從spark 2.1.0獲得。 還沒有添加評論抱歉。
import org.apache.spark.sql.functions.posexplode
import org.apache.spark.sql.expressions.Window
val df = Seq((1, "CS", 0, Array(0.1, 0.2, 0.4, 0.5)),
(4, "Ed", 0, Array(0.4, 0.8, 0.3, 0.6)),
(7, "CS", 0, Array(0.2, 0.5, 0.4, 0.7)),
(101, "CS", 1, Array(0.5, 0.7, 0.3, 0.8)),
(5, "CS", 1, Array(0.4, 0.2, 0.6, 0.9)))
.toDF("id", "dept", "test", "arraytoprocess")
scala> df.show()
+---+----+----+--------------------+
| id|dept|test| arraytoprocess|
+---+----+----+--------------------+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+
val columnIndicestoDrop = df.withColumn("zipRank",row_number().over(Window.partitionBy("dept")
.orderBy("id")))
.withColumn("pos",when($"test" === 1, $"zipRank"-1))
.filter('pos.isNotNull)
.select('pos)
.distinct()
scala> columnIndicestoDrop.show()
+---+
|pos|
+---+
| 1|
| 3|
+---+
val dfwitharrayIndices = df.select('id,
'dept,
'test,
'arraytoprocess,
posexplode($"arraytoprocess") as Seq("pos", "val"))
scala> dfwitharrayIndices.show()
+---+----+----+--------------------+---+---+
| id|dept|test| arraytoprocess|pos|val|
+---+----+----+--------------------+---+---+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]| 0|0.1|
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]| 1|0.2|
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]| 2|0.4|
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]| 3|0.5|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]| 0|0.4|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]| 1|0.8|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]| 2|0.3|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]| 3|0.6|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]| 0|0.2|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]| 1|0.5|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]| 2|0.4|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]| 3|0.7|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]| 0|0.5|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]| 1|0.7|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]| 2|0.3|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]| 3|0.8|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]| 0|0.4|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]| 1|0.2|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]| 2|0.6|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]| 3|0.9|
+---+----+----+--------------------+---+---+
val finaldataFrame = dfwitharrayIndices
.join(broadcast(columnIndicestoDrop),Seq("pos"),"leftanti")
.select('id,'dept,'test,'val)
.groupBy('id,'dept,'test)
.agg(collect_list('val).as("finalarray"))
scala> finaldataFrame.show()
+---+----+----+----------+
| id|dept|test|finalarray|
+---+----+----+----------+
| 5| CS| 1|[0.4, 0.6]|
| 4| Ed| 0|[0.4, 0.3]|
| 1| CS| 0|[0.1, 0.4]|
| 7| CS| 0|[0.2, 0.4]|
|101| CS| 1|[0.5, 0.3]|
+---+----+----+----------+
假設您有初始數據框為
+---+----+----+--------------------+
| id|dept|test| array|
+---+----+----+--------------------+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+
root
|-- id: integer (nullable = false)
|-- dept: string (nullable = true)
|-- test: integer (nullable = false)
|-- array: array (nullable = true)
| |-- element: double (containsNull = false)
您可以應用窗口函數來獲取行號為
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val w = Window.partitionBy("dept").orderBy("id")
val tempdf = df.withColumn("Index", row_number().over(w))
這會給你
+---+----+----+--------------------+-----+
|id |dept|test|array |Index|
+---+----+----+--------------------+-----+
|1 |CS |0 |[0.1, 0.2, 0.4, 0.5]|1 |
|5 |CS |1 |[0.4, 0.2, 0.6, 0.9]|2 |
|7 |CS |0 |[0.2, 0.5, 0.4, 0.7]|3 |
|101|CS |1 |[0.5, 0.7, 0.3, 0.8]|4 |
|4 |Ed |0 |[0.4, 0.8, 0.3, 0.6]|1 |
+---+----+----+--------------------+-----+
下一步將選擇具有dept = CS和test = 1的行並獲取索引列表
val csStudentIdIdxToRemove = tempdf.filter("dept = 'CS' and test = '1'").select(collect_list(tempdf("Index"))).collect()(0).getAs[Seq[Int]](0)
//WrappedArray(2, 4)
然后,您定義一個udf函數,使用所有邏輯將數組列中的元素刪除為
def removeUdf = udf((array: Seq[Double], additionalIndex: Int) => additionalIndex match{
case 0 => array.zipWithIndex.filterNot(x => csStudentIdIdxToRemove.contains(x._2 + 1)).map(_._1)
case _ => {
val withAdditionalIndex = csStudentIdIdxToRemove ++ Seq(additionalIndex)
array.zipWithIndex.filterNot(x => withAdditionalIndex.contains(x._2 + 1)).map(_._1)
}
})
然后調用udf函數並刪除Index列
tempdf.withColumn("array", removeUdf(tempdf("array"), when(tempdf("dept") === "CS" && tempdf("test") === 0, tempdf("Index")).otherwise(lit(0))))
.drop("Index")
最后你應該有你想要的結果
+---+----+----+----------+
|1 |CS |0 |[0.4] |
|5 |CS |1 |[0.4, 0.6]|
|7 |CS |0 |[0.2] |
|101|CS |1 |[0.5, 0.3]|
|4 |Ed |0 |[0.4, 0.3]|
+---+----+----+----------+
我希望答案簡潔明了,有幫助
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.