How to create a new column for dataset using “.withColumn” with many conditions in Scala Spark

I have the following input array

val bins = Array(("bin1", 1.0, 2.0), ("bin2", 3.0, 4.0), ("bin3", 5.0, 6.0))

Basically, the strings "bin1", "bin2", ... refer to values in a reference column on which the dataframe is filtered, and a new column is created from another column by clamping it to the boundary conditions given by the remaining two doubles in the array.

var number_of_dataframes = bins.length
var ctempdf = spark.createDataFrame(sc.emptyRDD[Row], train_data.schema)
ctempdf = ctempdf.withColumn(colName, col(colName))
val t1 = System.nanoTime
for (x <- 0 to bins.length - 1) {
      var tempdf = train_data.filter(col(refCol) === bins(x)._1)
      //println(bins(x)._1)
      tempdf = tempdf.withColumn(colName,
                                 when(col(colName) < bins(x)._2, bins(x)._2)
                                 .when(col(colName) > bins(x)._3, bins(x)._3)
                                 .otherwise(col(colName)))
      ctempdf = ctempdf.union(tempdf)
      val duration = (System.nanoTime - t1) / 1e9d
      println(duration)
}

The code above gets incrementally slower as the number of bins increases. Is there a way I can speed this up drastically? This code is itself nested in another loop.

I have tried checkpoint / persist / cache and they do not help.

There is no need for iterative union here; each union inside the loop grows the query plan, which is why the code slows down as the number of bins increases. Instead, create a literal map<string, struct<double, double>> using org.apache.spark.sql.functions.map (in functional terms it behaves like a delayed string => struct<lower: double, upper: double>):

import org.apache.spark.sql.functions._

val bins: Seq[(String, Double, Double)] = Seq(
  ("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))

// map(...) expects a flat sequence of alternating key / value columns,
// so each bin is expanded to Seq(key literal, struct literal) and flattened.
val binCol = map(bins.map { 
  case (key, lower, upper) => Seq(
    lit(key), 
    struct(lit(lower) as "lower", lit(upper) as "upper")) 
}.flatten: _*)

Define expressions like these (they are simple lookups in the predefined mapping, so binCol(col(refCol)) is a delayed struct<lower: double, upper: double> and the remaining apply takes the lower or upper field):

val lower = binCol(col(refCol))("lower")
val upper = binCol(col(refCol))("upper")
val c = col(colName)

and use CASE ... WHEN ... (Spark Equivalent of IF Then ELSE):

val result = when(c.between(lower, upper), c)
  .when(c < lower, lower)
  .when(c > upper, upper)

select and drop NULLs:

df
  .withColumn(colName, result)
  // If value is still NULL it means we didn't find refCol key in binCol keys.
  // To mimic .filter(col(refCol) === ...) we drop the rows
  .na.drop(Seq(colName))

This solution assumes that there are no NULL values in colName at the beginning, but it can easily be adjusted to handle cases where this assumption is not satisfied, for example as sketched below.
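Here is a minimal sketch of one such adjustment; it assumes colName may already contain NULLs and reuses df, binCol, lower, upper and c as defined above. Instead of relying on the CASE expression producing NULL for unknown keys, it filters explicitly on map membership so that pre-existing NULLs in colName are left untouched:

// Sketch only: binCol(col(refCol)) is NULL exactly when refCol is not a key
// of the map, so this filter mimics .filter(col(refCol) === ...) without
// dropping rows whose colName happens to be NULL already.
df
  .filter(binCol(col(refCol)).isNotNull)
  .withColumn(colName, when(c < lower, lower).when(c > upper, upper).otherwise(c))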

If the process is still unclear I'd recommend tracing it step-by-step with literals:

spark.range(1).select(binCol as "map").show(false)
+------------------------------------------------------------+
|map                                                         |
+------------------------------------------------------------+
|[bin1 -> [1.0, 2.0], bin2 -> [3.0, 4.0], bin3 -> [5.0, 6.0]]|
+------------------------------------------------------------+
spark.range(1).select(binCol(lit("bin1")) as "value").show(false)
+----------+
|value     |
+----------+
|[1.0, 2.0]|
+----------+
spark.range(1).select(binCol(lit("bin1"))("lower") as "value").show
+-----+
|value|
+-----+
|  1.0|
+-----+

For further reading, see Querying Spark SQL DataFrame with complex types.
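Finally, a minimal end-to-end sketch that puts the pieces together. The column names "bin" and "value" and the sample rows are hypothetical, chosen only for illustration (here refCol = "bin" and colName = "value"):

import org.apache.spark.sql.functions._
import spark.implicits._

// Hypothetical data, one row per case of interest.
val df = Seq(
  ("bin1", 0.5),   // below bin1's lower bound -> clamped to 1.0
  ("bin2", 3.5),   // inside bin2's range      -> kept as 3.5
  ("bin3", 9.9),   // above bin3's upper bound -> clamped to 6.0
  ("bin4", 2.0)    // key not in the map       -> row dropped
).toDF("bin", "value")

val lo = binCol(col("bin"))("lower")
val hi = binCol(col("bin"))("upper")
val v  = col("value")

df
  .withColumn("value", when(v.between(lo, hi), v).when(v < lo, lo).when(v > hi, hi))
  .na.drop(Seq("value"))
  .show()

The unknown key "bin4" produces a NULL result and is removed by na.drop, while the other three rows are clamped to their bin's bounds.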
