I am using Spark 2.3 in my Scala application. I have a dataframe created from Spark SQL, named sqlDF in the sample code below. I also have a string list, stringList, with the following items:
List[String] stringList items:
-9, -8, -7, -6
I want to replace every value, in every column of the dataframe, that matches an item in this list with 0.
Initial dataframe:
column1 | column2 | column3
1       | 1       | 1
2       | -5      | 1
6       | -6      | 1
-7      | -8      | -7
It must return:
column1 | column2 | column3
1       | 1       | 1
2       | -5      | 1
6       | 0       | 1
0       | 0       | 0
For this, I run the statement below in a loop over every column (more than 500) in sqlDF:
sqlDF = sqlDF.withColumn(currColumnName, when(col(currColumnName).isin(stringList:_*), 0).otherwise(col(currColumnName)))
But I get the error below. If I iterate over only one column it works, but when I run the code above over all 500 columns it fails:
Exception in thread "streaming-job-executor-0" java.lang.StackOverflowError
    at scala.collection.generic.GenTraversableFactory$GenericCanBuildFrom.apply(GenTraversableFactory.scala:57)
    at scala.collection.generic.GenTraversableFactory$GenericCanBuildFrom.apply(GenTraversableFactory.scala:52)
    at scala.collection.TraversableLike$class.builder$1(TraversableLike.scala:229)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:233)
    at scala.collection.immutable.List.map(List.scala:285)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:333)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
What am I missing?
Here is a different approach that applies a left anti join between columnX and X, where X is your list of items converted into a dataframe. The left anti join returns all the rows whose value is not present in X. The per-column results are then combined with an outer join on the id assigned by monotonically_increasing_id. The outer join can be replaced with a left join for better performance, but that drops every row whose first column was filtered out, such as the all-zero record here (id == 3):
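import org.apache.spark.sql.functions.{monotonically_increasing_id, col}
import spark.implicits._ // needed for toDF and $"..." (already in scope in spark-shell)

// Sample data, plus a row id so the per-column results can be joined back together
val df = Seq(
  (1, 1, 1),
  (2, -5, 1),
  (6, -6, 1),
  (-7, -8, -7))
  .toDF("c1", "c2", "c3")
  .withColumn("id", monotonically_increasing_id())

// The values to remove, as a one-column dataframe
val exdf = Seq(-9, -8, -7, -6).toDF("x")

df.columns.filter(_ != "id").map { c => // skip the id column itself
  // Keep only the rows whose value in column c is NOT in exdf
  df.select("id", c).join(exdf, col(c) === $"x", "left_anti")
}
  .reduce((df1, df2) => df1.join(df2, Seq("id"), "outer"))
  .na.fill(0) // the values removed above come back as null; turn them into 0
  .show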
Output:
+---+---+---+---+
| id| c1| c2| c3|
+---+---+---+---+
| 0| 1| 1| 1|
| 1| 2| -5| 1|
| 3| 0| 0| 0|
| 2| 6| 0| 1|
+---+---+---+---+
foldLeft works well for your case, as shown below:
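import org.apache.spark.sql.functions.{col, when}
import spark.implicits._ // needed for toDF (already in scope in spark-shell)

val df = spark.sparkContext.parallelize(Seq(
  (1, 1, 1),
  (2, -5, 1),
  (6, -6, 1),
  (-7, -8, -7)
)).toDF("a", "b", "c")

// Values to replace with 0 (-6 is not in this list, so it is left unchanged)
val list = Seq(-7, -8, -9)

// Fold over the column names, rewriting one column per step
val resultDF = df.columns.foldLeft(df) { (acc, name) =>
  acc.withColumn(name, when(col(name).isin(list: _*), 0).otherwise(col(name)))
}
resultDF.show(false)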
Output:
+---+---+---+
|a |b |c |
+---+---+---+
|1 |1 |1 |
|2 |-5 |1 |
|6 |-6 |1 |
|0 |0 |0 |
+---+---+---+
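Note that the foldLeft version still adds one withColumn call per column, just like the loop in the question, so with 500+ columns the query plan can become deeply nested, which is a likely cause of the StackOverflowError in the analyzer. A minimal sketch of an alternative, assuming the columns are integer typed: build every column expression first and apply them all in a single select. The object name ReplaceInOneSelect and the helper replaceAll are made up for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, when}

object ReplaceInOneSelect {
  // Hypothetical helper: replaces every value found in `targets` with 0.
  // All column expressions are built up front and applied in ONE select,
  // so the plan gets a single projection instead of one per column.
  def replaceAll(df: DataFrame, targets: Seq[Int]): DataFrame = {
    val exprs = df.columns.map { name =>
      when(col(name).isin(targets: _*), 0).otherwise(col(name)).as(name)
    }
    df.select(exprs: _*)
  }

  def main(args: Array[String]): Unit = {
    // Local session for the demo only; in a real job reuse the existing session
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("replace-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, 1, 1), (2, -5, 1), (6, -6, 1), (-7, -8, -7))
      .toDF("column1", "column2", "column3")

    replaceAll(df, Seq(-9, -8, -7, -6)).show()
    spark.stop()
  }
}
```

With the sample data from the question this should give the expected table (the -6, -7 and -8 values become 0). Column names containing dots or spaces would need backtick quoting inside col(...), which this sketch does not handle.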
I would suggest broadcasting the list of strings:
val stringList=sc.broadcast(<Your List of List[String]>)
After that, use this:
sqlDF = sqlDF.withColumn(currColumnName, when(col(currColumnName).isin(stringList.value:_*), 0).otherwise(col(currColumnName)))
Make sure the column named by currColumnName is also of string type. The comparison should be string to string.
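One way to keep the comparison string to string without changing the stored column type is to cast the column only inside the membership test. This is a rough sketch, assuming integer columns and a list of strings as in the question; the object name StringListReplace and the helper replaceMatches are hypothetical.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit, when}

object StringListReplace {
  // Hypothetical helper: `stringList` holds the values as strings, so each
  // column is cast to string for the isin test only. The otherwise branch
  // returns the original column, so integer columns stay integer.
  def replaceMatches(df: DataFrame, stringList: Seq[String]): DataFrame = {
    val exprs = df.columns.map { name =>
      when(col(name).cast("string").isin(stringList: _*), lit(0))
        .otherwise(col(name))
        .as(name)
    }
    df.select(exprs: _*)
  }
}
```

It would be called as StringListReplace.replaceMatches(sqlDF, stringList). For columns that are already strings, Spark has to reconcile the integer 0 with the string branch, so the result type for those columns is something to check in your own job.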