
Spark/Scala: Remove some component from a DataFrame with Array typed column

The title may not be very clear, so let me explain what I want to achieve with an example, starting with the following DataFrame:

val df = Seq((1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)), 
             (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
             (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
             (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
             (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9)))
             .toDF("id", "dept", "test", "array")
+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

I want to remove/drop some elements of the array column according to the information in the id, dept and test columns. Specifically, the 4 elements in each array correspond to the four ids in the CS dept, ordered by ascending id (i.e. 1, 5, 7, 101). Now I want to remove the elements in each array that correspond to the ids whose test column is 1. In this example, the 2nd and 4th elements will be removed and the result will look like this:

+---+----+----+----------+
| id|dept|test|     array|
+---+----+----+----------+
|  1|  CS|   0|[0.1, 0.4]|
|  4|  Ed|   0|[0.4, 0.3]|
|  7|  CS|   0|[0.2, 0.4]|
|101|  CS|   1|[0.5, 0.3]|
|  5|  CS|   1|[0.4, 0.6]|
+---+----+----+----------+

To avoid collecting all the results and doing the manipulation in Scala, I would like to keep the operation in Spark DataFrames if possible. My plan to tackle this problem involves two steps:

  1. Figure out the Index of array elements that need to be removed
  2. Apply the remove/drop operation

So far, I think I have figured out step 1 as follows:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // rank each id within its dept in ascending id order
    val w = Window.partitionBy("dept").orderBy("id")
    val studentIdIdx = df.select("id", "dept")
      .withColumn("Index", row_number().over(w))
      .where("dept = 'CS'").drop("dept")
    studentIdIdx.show()
    +---+-----+
    | id|Index|
    +---+-----+
    |  1|    1|
    |  5|    2|
    |  7|    3|
    |101|    4|
    +---+-----+
    // 1-based indices (within CS) of the ids whose test = 1
    val testIds = df.where("test = 1")
      .select($"id".as("test_id"))
    val testMask = studentIdIdx
      .join(testIds, studentIdIdx("id") === testIds("test_id"))
      .drop("id", "test_id")
    testMask.show()
    +-----+
    |Index|
    +-----+
    |    2|
    |    4|
    +-----+

So my two related questions are:

  1. How can I apply the remove/drop operation to the array in each row using the Index? (I am open to suggestions for a better way to figure out the Index as well.)

  2. The real final DataFrame that I want should remove some more elements on top of the above result. Specifically, for test=0 & dept=CS rows, it should also remove the array element that corresponds to the Index of the row's own id. In this example, the 1st element in the row with id=1 and the 3rd element (original index before any removal) in the row with id=7 should be removed, and the real final result is:

     +---+----+----+----------+
     | id|dept|test|     array|
     +---+----+----+----------+
     |  1|  CS|   0|     [0.4]|
     |  4|  Ed|   0|[0.4, 0.3]|
     |  7|  CS|   0|     [0.2]|
     |101|  CS|   1|[0.5, 0.3]|
     |  5|  CS|   1|[0.4, 0.6]|
     +---+----+----+----------+

I mention the second point in case there is a more efficient way to achieve both remove operations together. If not, I think I should be able to figure out how to do the second removal once I know how to use the Index information for the first. Thanks!

Here is my solution. To drop the indices, I would use a UDF:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df = Seq(
  (1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
  (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
  (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
  (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
  (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
).toDF("id", "dept", "test", "array")

// drops the array elements whose 1-based position appears in `indices`
val dropElements = udf(
  (array: Seq[Double], indices: Seq[Int]) =>
    array.zipWithIndex.filterNot { case (_, i) => indices.contains(i + 1) }.map(_._1)
)

df
  // 1-based position of each id within its dept, in ascending id order
  .withColumn("index_to_drop", row_number().over(Window.partitionBy($"dept").orderBy($"id")))
  // keep the position only where test = 1; other rows become null
  .withColumn("index_to_drop", when($"test" === 1, $"index_to_drop"))
  // collect_list skips nulls, so every row gets its dept's positions to drop;
  // note the list is collected per dept, which is why the Ed row below keeps its full array
  .withColumn("indices_to_drop", collect_list($"index_to_drop").over(Window.partitionBy($"dept")))
  .withColumn("array", dropElements($"array", $"indices_to_drop"))
  .select($"id", $"dept", $"test", $"array")
  .show()


+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|          [0.1, 0.4]|
|  5|  CS|   1|          [0.4, 0.6]|
|  7|  CS|   0|          [0.2, 0.4]|
|101|  CS|   1|          [0.5, 0.3]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
+---+----+----+--------------------+
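
This covers the first question. As a rough sketch (not part of the original answer) of how the same UDF approach could also handle the second question, the UDF can take the row's own 1-based position as an extra argument; dropElementsPlus, rank and own are names introduced here for illustration. Like the answer above, the indices are collected per dept, so the Ed row keeps its full array:

// like dropElements, but also drops the element at the row's own
// 1-based position `own` (0 means nothing extra to drop)
val dropElementsPlus = udf(
  (array: Seq[Double], indices: Seq[Int], own: Int) =>
    array.zipWithIndex
      .filterNot { case (_, i) => indices.contains(i + 1) || i + 1 == own }
      .map(_._1)
)

df
  .withColumn("rank", row_number().over(Window.partitionBy($"dept").orderBy($"id")))
  .withColumn("index_to_drop", when($"test" === 1, $"rank"))
  .withColumn("indices_to_drop", collect_list($"index_to_drop").over(Window.partitionBy($"dept")))
  // dept = CS, test = 0 rows additionally drop the element at their own position
  .withColumn("own", when($"dept" === "CS" && $"test" === 0, $"rank").otherwise(lit(0)))
  .withColumn("array", dropElementsPlus($"array", $"indices_to_drop", $"own"))
  .select($"id", $"dept", $"test", $"array")
  .show()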

This is a different solution without a UDF. I would advise you to avoid UDFs as much as you can. The posexplode function is available from Spark 2.1.0.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val df = Seq((1, "CS", 0, Array(0.1, 0.2, 0.4, 0.5)), 
             (4, "Ed", 0, Array(0.4, 0.8, 0.3, 0.6)),
             (7, "CS", 0, Array(0.2, 0.5, 0.4, 0.7)),
             (101, "CS", 1, Array(0.5, 0.7, 0.3, 0.8)),
             (5, "CS", 1, Array(0.4, 0.2, 0.6, 0.9)))
             .toDF("id", "dept", "test", "arraytoprocess")

scala> df.show()
+---+----+----+--------------------+
| id|dept|test|      arraytoprocess|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

// 0-based positions (within each dept, ordered by id) of the rows with test = 1
val columnIndicestoDrop = df.withColumn("zipRank", row_number().over(Window.partitionBy("dept")
    .orderBy("id")))
    .withColumn("pos", when($"test" === 1, $"zipRank" - 1))
    .filter('pos.isNotNull)
    .select('pos)
    .distinct()

scala> columnIndicestoDrop.show()
+---+
|pos|
+---+
|  1|
|  3|
+---+

// explode each array together with the 0-based position of every element
val dfwitharrayIndices = df.select('id,
    'dept,
    'test,
    'arraytoprocess,
    posexplode($"arraytoprocess") as Seq("pos", "val"))

scala> dfwitharrayIndices.show()
+---+----+----+--------------------+---+---+
| id|dept|test|      arraytoprocess|pos|val|
+---+----+----+--------------------+---+---+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  0|0.1|
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  1|0.2|
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  2|0.4|
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|  3|0.5|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  0|0.4|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  1|0.8|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  2|0.3|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|  3|0.6|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  0|0.2|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  1|0.5|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  2|0.4|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|  3|0.7|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  0|0.5|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  1|0.7|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  2|0.3|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|  3|0.8|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  0|0.4|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  1|0.2|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  2|0.6|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|  3|0.9|
+---+----+----+--------------------+---+---+

// the anti-join drops the exploded rows at the unwanted positions,
// then the arrays are rebuilt per id
val finaldataFrame = dfwitharrayIndices
    .join(broadcast(columnIndicestoDrop), Seq("pos"), "leftanti")
    .select('id, 'dept, 'test, 'val)
    .groupBy('id, 'dept, 'test)
    .agg(collect_list('val).as("finalarray"))

scala> finaldataFrame.show()
+---+----+----+----------+
| id|dept|test|finalarray|
+---+----+----+----------+
|  5|  CS|   1|[0.4, 0.6]|
|  4|  Ed|   0|[0.4, 0.3]|
|  1|  CS|   0|[0.1, 0.4]|
|  7|  CS|   0|[0.2, 0.4]|
|101|  CS|   1|[0.5, 0.3]|
+---+----+----+----------+
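
As a follow-up sketch in the same no-UDF spirit, assuming Spark 3.0+ (where the lambda passed to the built-in filter function can take an (element, index) pair), the explode/anti-join/groupBy round trip could be replaced by filtering each array in place; positionsDF and positions are names introduced here, and columnIndicestoDrop is reused from above:

// gather the 0-based positions to drop into a single one-row array column
val positionsDF = columnIndicestoDrop.agg(collect_list($"pos").as("positions"))

val result = df
  .crossJoin(broadcast(positionsDF))
  // Spark 3.0+: keep only the elements whose index is not in `positions`
  .withColumn("arraytoprocess",
    expr("filter(arraytoprocess, (x, i) -> NOT array_contains(positions, i))"))
  .drop("positions")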

Assuming that you have the initial dataframe as

+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

root
 |-- id: integer (nullable = false)
 |-- dept: string (nullable = true)
 |-- test: integer (nullable = false)
 |-- array: array (nullable = true)
 |    |-- element: double (containsNull = false)

You can apply a window function to get the row number as

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("dept").orderBy("id")

val tempdf = df.withColumn("Index", row_number().over(w))

which would give you

+---+----+----+--------------------+-----+
|id |dept|test|array               |Index|
+---+----+----+--------------------+-----+
|1  |CS  |0   |[0.1, 0.2, 0.4, 0.5]|1    |
|5  |CS  |1   |[0.4, 0.2, 0.6, 0.9]|2    |
|7  |CS  |0   |[0.2, 0.5, 0.4, 0.7]|3    |
|101|CS  |1   |[0.5, 0.7, 0.3, 0.8]|4    |
|4  |Ed  |0   |[0.4, 0.8, 0.3, 0.6]|1    |
+---+----+----+--------------------+-----+

The next step would be to select the rows with dept = 'CS' and test = 1 and collect the index list

val csStudentIdIdxToRemove = tempdf
  .filter("dept = 'CS' and test = 1")
  .select(collect_list(tempdf("Index")))
  .collect()(0).getAs[Seq[Int]](0)
//WrappedArray(2, 4)

Then you define a udf function that removes the elements from the array column using all of your logic as

// additionalIndex is the row's own 1-based Index for dept = CS, test = 0 rows,
// and 0 otherwise (see the call below)
def removeUdf = udf((array: Seq[Double], additionalIndex: Int) => additionalIndex match {
  case 0 => array.zipWithIndex.filterNot(x => csStudentIdIdxToRemove.contains(x._2 + 1)).map(_._1)
  case _ => {
    val withAdditionalIndex = csStudentIdIdxToRemove ++ Seq(additionalIndex)
    array.zipWithIndex.filterNot(x => withAdditionalIndex.contains(x._2 + 1)).map(_._1)
  }
})

and then call the udf function and drop the Index column

tempdf.withColumn("array", removeUdf(tempdf("array"), when(tempdf("dept") === "CS" && tempdf("test") === 0, tempdf("Index")).otherwise(lit(0))))
    .drop("Index")

Finally you should have your desired result

+---+----+----+----------+
|id |dept|test|array     |
+---+----+----+----------+
|1  |CS  |0   |[0.4]     |
|5  |CS  |1   |[0.4, 0.6]|
|7  |CS  |0   |[0.2]     |
|101|CS  |1   |[0.5, 0.3]|
|4  |Ed  |0   |[0.4, 0.3]|
+---+----+----+----------+

I hope the answer is concise and helpful.
