How to replace multiple rows, duplicated in one column, with a single row of nulls in a Spark dataframe

I have a Spark dataframe like this:

+------+---------+---------+---------+---------+
| name | metric1 | metric2 | metric3 | metric4 | 
+------+---------+---------+---------+---------+
| a    |       1 |       2 |       3 |       4 | 
| b    |       1 |       2 |       3 |       4 | 
| c    |       3 |       1 |       5 |       4 | 
| a    |       3 |       3 |       3 |       3 | 
+------+---------+---------+---------+---------+

For any name that appears more than once, I want to replace its rows with a single row containing nulls, so the desired output is:

+------+---------+---------+---------+---------+
| name | metric1 | metric2 | metric3 | metric4 | 
+------+---------+---------+---------+---------+
| a    |    null |    null |    null |    null | 
| b    |       1 |       2 |       3 |       4 | 
| c    |       3 |       1 |       5 |       4 | 
+------+---------+---------+---------+---------+

The following works:

import org.apache.spark.sql.functions._

val df = Seq(
  ("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3)
).toDF("name", "metric1", "metric2", "metric3", "metric4")

val newDf = df
  .groupBy(col("name"))
  .agg(
    min(col("metric1")).as("metric1"),
    min(col("metric2")).as("metric2"),
    min(col("metric3")).as("metric3"),
    min(col("metric4")).as("metric4"),
    count(col("name")).as("NumRecords")
  )
  // =!= is the Column inequality operator; !== is deprecated since Spark 2.0
  .withColumn("metric1", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric1")))
  .withColumn("metric2", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric2")))
  .withColumn("metric3", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric3")))
  .withColumn("metric4", when(col("NumRecords") =!= 1, lit(null)).otherwise(col("metric4")))
  .drop("NumRecords")

but surely there has to be a better way?

    import org.apache.spark.sql.functions._

    val df = Seq(("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3))
      .toDF("name", "metric1", "metric2", "metric3", "metric4")

    // Aggregate as in the question, keeping a per-name row count.
    val newDf = df.groupBy(col("name")).agg(
      min(col("metric1")).as("metric1"), min(col("metric2")).as("metric2"),
      min(col("metric3")).as("metric3"), min(col("metric4")).as("metric4"),
      count(col("name")).as("NumRecords"))

    // Every column except the grouping key.
    val colArr2 = df.columns.diff(Array("name"))

    // Null out each metric column wherever the name occurred more than once.
    val reqDF = colArr2.foldLeft(newDf) { (acc, colName) =>
      acc.withColumn(colName, when(col("NumRecords") =!= 1, lit(null)).otherwise(col(colName)))
    }.drop("NumRecords")

    // Row order after groupBy is not guaranteed.
    reqDF.show()
    +----+-------+-------+-------+-------+
    |name|metric1|metric2|metric3|metric4|
    +----+-------+-------+-------+-------+
    |   c|      3|      1|      5|      4|
    |   b|      1|      2|      3|      4|
    |   a|   null|   null|   null|   null|
    +----+-------+-------+-------+-------+

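Try the approach above: it uses foldLeft over the metric column names, so the when/otherwise expression is written only once instead of once per column.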
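The same idea can be pushed one step further by generating the aggregate expressions themselves, so the min(...) calls are not repeated either. The following is only a sketch under some assumptions: the object and helper name (NullOutDuplicates, nullOutDuplicates) are hypothetical, the local SparkSession is only for running it outside spark-shell, and the dataframe is assumed to have at least one column besides the key. It relies on when(...) without otherwise(...) evaluating to null when the condition is false.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, lit, min, when}

object NullOutDuplicates {
  // Hypothetical helper: one output row per key; non-key columns are kept
  // only for keys that occur exactly once, otherwise they become null.
  // Assumes df has at least one column other than `key`.
  def nullOutDuplicates(df: DataFrame, key: String): DataFrame = {
    val aggExprs = df.columns.filterNot(_ == key).map { c =>
      // With a single row in the group, min(c) is just that row's value.
      // when(...) with no otherwise(...) yields null for larger groups.
      when(count(lit(1)) === 1, min(col(c))).as(c)
    }
    df.groupBy(col(key)).agg(aggExprs.head, aggExprs.tail: _*)
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration; in spark-shell `spark` already exists.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("null-out-duplicates")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("a", 1, 2, 3, 4), ("b", 1, 2, 3, 4), ("c", 3, 1, 5, 4), ("a", 3, 3, 3, 3)
    ).toDF("name", "metric1", "metric2", "metric3", "metric4")

    // Sort only to make the printed output stable.
    nullOutDuplicates(df, "name").orderBy("name").show()

    spark.stop()
  }
}
```

This does the work in a single aggregation, with no intermediate NumRecords column to add and drop. Whether it reads better than the foldLeft version is a matter of taste.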
