
Spark/Scala: Operating on selected elements of an array-typed DataFrame column

Let me explain what I want to achieve with an example. Starting with the following DataFrame:

val df = Seq((1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
             (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
             (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
             (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
             (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9)))
             .toDF("id", "dept", "test", "array")
+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+

I want to change some elements of the array column based on the values in the id, dept and test columns. I first add an Index to each row within each dept as follows:

@transient val w = Window.partitionBy("dept").orderBy("id")
val tempdf = df.withColumn("Index", row_number().over(w))
tempdf.show
+---+----+----+--------------------+-----+
| id|dept|test|               array|Index|
+---+----+----+--------------------+-----+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|    1|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|    2|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|    3|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|    4|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|    1|
+---+----+----+--------------------+-----+
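To make the indexing step concrete, here is a plain-Scala sketch (no Spark) of the numbering that row_number() over a window partitioned by dept and ordered by id produces: group the rows by dept, sort each group by id, and number from 1. It only illustrates the result, not how Spark evaluates the window; the Rec case class and the rowNumbers helper are hypothetical names standing in for the id and dept columns.

```scala
// Plain-Scala illustration (hypothetical helper, not Spark) of the numbering that
// row_number().over(Window.partitionBy("dept").orderBy("id")) assigns.
object RowNumberSketch {
  // Hypothetical record holding only the columns that drive the numbering.
  final case class Rec(id: Int, dept: String)

  def rowNumbers(rows: Seq[Rec]): Map[Rec, Int] =
    rows
      .groupBy(_.dept)                          // like partitionBy("dept")
      .values
      .flatMap { group =>
        group.sortBy(_.id)                      // like orderBy("id")
          .zipWithIndex
          .map { case (r, i) => r -> (i + 1) }  // row_number() starts at 1
      }
      .toMap

  def main(args: Array[String]): Unit = {
    val rows = Seq(Rec(1, "CS"), Rec(4, "Ed"), Rec(7, "CS"), Rec(101, "CS"), Rec(5, "CS"))
    // Print in (dept, id) order so the listing lines up with the table above.
    rowNumbers(rows).toSeq.sortBy { case (r, _) => (r.dept, r.id) }.foreach(println)
  }
}
```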

What I want to achieve is to subtract a constant (0.1) from one element of the array column, namely the element whose position corresponds to the row's Index within its dept. For example, in the "dept==CS" case, the final result should be:

+---+----+----+--------------------+-----+
| id|dept|test|               array|Index|
+---+----+----+--------------------+-----+
|  1|  CS|   0|[0.0, 0.2, 0.4, 0.5]|    1|
|  5|  CS|   1|[0.4, 0.1, 0.6, 0.9]|    2|
|  7|  CS|   0|[0.2, 0.5, 0.3, 0.7]|    3|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.7]|    4|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|    1|
+---+----+----+--------------------+-----+

Currently, I am thinking of achieving this with a UDF as follows:

def subUdf = udf((array: Seq[Double], dampFactor: Double, additionalIndex: Int) => additionalIndex match {
  case 0 => array
  case _ => {
    val temp = array.zipWithIndex
    var mask = Array.fill(array.length)(0.0)
    mask(additionalIndex-1) = dampFactor
    val tempAdj = temp.map(x => if (additionalIndex == (x._2+1)) (x._1-mask, x._2) else x)
    tempAdj.map(_._1)
  }
})
val dampFactor = 0.1
val finaldf = tempdf.withColumn("array", subUdf(tempdf("array"), dampFactor, when(tempdf("dept") === "CS" && tempdf("test") === 0, tempdf("Index")).otherwise(lit(0)))).drop("Index")

The UDF fails to compile with an error about the overloaded - method:

Name: Compile Error
Message: <console>:34: error: overloaded method value - with alternatives:
  (x: Double)Double <and>
  (x: Float)Double <and>
  (x: Long)Double <and>
  (x: Int)Double <and>
  (x: Char)Double <and>
  (x: Short)Double <and>
  (x: Byte)Double
 cannot be applied to (Array[Double])
            val tempAdj = temp.map(x => if (additionalIndex == (x._2+1)) (x._1-mask, x._2) else x)
           ^

Two related questions:

  1. How to resolve the compile error?

  2. I am also open to suggestions for approaches other than a UDF.

If I understand your requirement correctly, you can create a UDF that takes the dampFactor, the array column and the window index column to transform the dataframe as follows:

val df = Seq(
  (1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)), 
  (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
  (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
  (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
  (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
).toDF("id", "dept", "test", "array")

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("dept").orderBy("id")
val tempdf = df.withColumn("index", row_number().over(w))

def adjustSeq(dampFactor: Double) = udf(
  (seq: Seq[Double], index: Int) =>
    seq.indices.map(i =>
      if (i == index - 1) seq(i) - dampFactor else seq(i)
    )
)

val finaldf = tempdf.
  withColumn("array", adjustSeq(0.1)($"array", $"index")).
  drop("index")

finaldf.show(false)
// +---+----+----+------------------------------------+
// |id |dept|test|array                               |
// +---+----+----+------------------------------------+
// |1  |CS  |0   |[0.0, 0.2, 0.4, 0.5]                |
// |5  |CS  |1   |[0.4, 0.1, 0.6, 0.9]                |
// |7  |CS  |0   |[0.2, 0.5, 0.30000000000000004, 0.7]|
// |101|CS  |1   |[0.5, 0.7, 0.3, 0.7000000000000001] |
// |4  |Ed  |0   |[0.30000000000000004, 0.8, 0.3, 0.6]|
// +---+----+----+------------------------------------+
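As for the compile error in question 1: in the original UDF, mask is an Array[Double], so x._1-mask tries to subtract a whole array from a single Double, and none of the Double.- overloads listed in the error accepts an Array[Double]. Subtracting a Double (dampFactor itself, or mask(x._2)) fixes it. Below is a sketch of the corrected UDF body as a plain function so it can be checked without Spark; the SubtractAtSketch and subtractAt names are hypothetical, and additionalIndex keeps the question's meaning (1-based, with 0 meaning "leave unchanged").

```scala
// Corrected body of the question's subUdf, written as a plain function (no Spark needed).
object SubtractAtSketch {
  // additionalIndex is 1-based; 0 means "return the array unchanged".
  def subtractAt(array: Seq[Double], dampFactor: Double, additionalIndex: Int): Seq[Double] =
    additionalIndex match {
      case 0 => array
      case _ =>
        array.zipWithIndex.map { case (value, i) =>
          // The original `x._1 - mask` subtracted an Array[Double] from a Double,
          // which caused the overload error; here a Double is subtracted instead.
          if (i == additionalIndex - 1) value - dampFactor else value
        }
    }

  def main(args: Array[String]): Unit = {
    println(subtractAt(Seq(0.2, 0.5, 0.4, 0.7), 0.1, 3)) // third element reduced by 0.1
    println(subtractAt(Seq(0.4, 0.2, 0.6, 0.9), 0.1, 0)) // index 0: unchanged
  }
}
```

In Spark this can be wrapped as udf((a: Seq[Double], d: Double, i: Int) => SubtractAtSketch.subtractAt(a, d, i)). Note that UDF arguments must be Columns, so the plain dampFactor passed in the question's subUdf(tempdf("array"), dampFactor, ...) call would also need to be lit(dampFactor); the curried adjustSeq(dampFactor) above sidesteps this by capturing the constant in the closure. Values such as 0.30000000000000004 in the output are ordinary binary floating-point rounding of 0.4 - 0.1, not a logic error.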

Your sample code appears to include some additional logic not described in the requirement:

val finaldf = tempdf.withColumn("array", subUdf(tempdf("array"), dampFactor, when(tempdf("dept") === "CS" && tempdf("test") === 0, tempdf("Index")).otherwise(lit(0)))).drop("Index")

To factor in the additional logic:

def adjustSeq(dampFactor: Double) = udf(
  (seq: Seq[Double], index: Int, dept: String, test: Int) =>
    (dept, test) match {
      case ("CS", 0) =>
        seq.indices.map(i =>
          if (i == index - 1) seq(i) - dampFactor else seq(i)
        )
      case _ => seq
    }
)

val finaldf = tempdf.
  withColumn("array", adjustSeq(0.1)($"array", $"index", $"dept", $"test")).
  drop("index")

finaldf.show(false)
// +---+----+----+------------------------------------+
// |id |dept|test|array                               |
// +---+----+----+------------------------------------+
// |1  |CS  |0   |[0.0, 0.2, 0.4, 0.5]                |
// |5  |CS  |1   |[0.4, 0.2, 0.6, 0.9]                |
// |7  |CS  |0   |[0.2, 0.5, 0.30000000000000004, 0.7]|
// |101|CS  |1   |[0.5, 0.7, 0.3, 0.8]                |
// |4  |Ed  |0   |[0.4, 0.8, 0.3, 0.6]                |
// +---+----+----+------------------------------------+
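For question 2, a UDF is not strictly necessary. Assuming Spark 3.0 or later, where org.apache.spark.sql.functions.transform accepts a (Column, Column) => Column lambda whose second argument is the 0-based element position, the same adjustment (including the dept/test condition) can be written with built-in column expressions. This is a sketch under that assumption; the TransformSketch object, its sample and adjust helpers, and the local SparkSession setup are illustrative names, not part of the original answer.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Sketch (assumes Spark 3.0+): adjust one array element per row without a UDF.
object TransformSketch {
  // Same sample data as the question; Seq[Double] yields an array<double> column.
  def sample(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(
      (1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
      (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
      (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
      (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
      (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
    ).toDF("id", "dept", "test", "array")
  }

  def adjust(df: DataFrame, dampFactor: Double): DataFrame = {
    val w = Window.partitionBy("dept").orderBy("id")
    df.withColumn("index", row_number().over(w))
      .withColumn("array", transform(col("array"), (x: Column, i: Column) =>
        // i is the 0-based element position; "index" is the 1-based row number within dept.
        when(col("dept") === "CS" && col("test") === 0 && i === col("index") - 1, x - dampFactor)
          .otherwise(x)))
      .drop("index")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("transform-sketch").getOrCreate()
    adjust(sample(spark), 0.1).show(false)
    spark.stop()
  }
}
```

Keeping the logic in built-in expressions means Catalyst can see the whole computation, whereas a Scala UDF is opaque to the optimizer and requires converting each array to a Scala Seq and back.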
