
Using Spark SQL and Spark DataFrames, how can we find the COLUMN NAME corresponding to the minimum value in a row?

I have a DataFrame df with 4 columns:

+-------+-------+-------+-------+  
| dist1 | dist2 | dist3 | dist4 |
+-------+-------+-------+-------+  
|  42   |  53   |  24   |  17   |
+-------+-------+-------+-------+  

The output I want is:

dist4

It seems easy, but I could not find a proper solution using the DataFrame API or a Spark SQL query.

You may use the least function:

select least(dist1,dist2,dist3,dist4) as min_dist
  from yourTable;

For the opposite case, greatest may be used.

EDIT: To detect column names, the following may be used to get (value, name) rows:

select inline(array(struct(42, 'dist1'), struct(53, 'dist2'), 
                    struct(24, 'dist3'), struct(17, 'dist4') ))

42  dist1
53  dist2
24  dist3
17  dist4 

and then the min function may be applied to get dist4, as sketched below.
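A minimal sketch of that second step (my own, not from the answer), assuming df is the single-row DataFrame from the question and spark is an active SparkSession; with several rows you would need a per-row key to group on:

df.createOrReplaceTempView("yourTable")

// inline explodes the (value, name) pairs into rows; min over a struct orders by
// its first field (the value), so .name of the minimal struct is the wanted column
spark.sql("""
  SELECT min(struct(t.value, t.name)).name AS min_dist_col
  FROM yourTable
  LATERAL VIEW inline(array(
      named_struct('value', dist1, 'name', 'dist1'),
      named_struct('value', dist2, 'name', 'dist2'),
      named_struct('value', dist3, 'name', 'dist3'),
      named_struct('value', dist4, 'name', 'dist4'))) t AS value, name
""").show()
// expected to print dist4 for the row (42, 53, 24, 17)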

Try this,

df.show
+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|  1|  2|  3|  4|
|  5|  4|  3|  1|
+---+---+---+---+

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// tag every cell with its own column name, e.g. value 1 in column A becomes "1,A"
val temp_df = df.columns.foldLeft(df) { (acc: DataFrame, colName: String) => acc.withColumn(colName, concat(col(colName), lit("," + colName))) }

// pick the smallest "value,column" string in each row's array and return the column part
val minval = udf((ar: Seq[String]) => ar.min.split(",")(1))

val result = temp_df.withColumn("least", split(concat_ws(":", temp_df.columns.map(col(_)): _*), ":")).withColumn("least_col", minval(col("least")))

result.show
+---+---+---+---+--------------------+---------+
|  A|  B|  C|  D|               least|least_col|
+---+---+---+---+--------------------+---------+
|1,A|2,B|3,C|4,D|[1,A, 2,B, 3,C, 4,D]|        A|
|5,A|4,B|3,C|1,D|[5,A, 4,B, 3,C, 1,D]|        D|
+---+---+---+---+--------------------+---------+
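One caveat (my own note, not part of the answer): ar.min compares the "value,column" strings lexicographically, which works for the single-digit values above but would misorder, say, "10,A" against "2,B". A hedged numeric variant, under the same setup:

val minvalNumeric = udf((ar: Seq[String]) => ar.minBy(_.split(",")(0).toInt).split(",")(1))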

An RDD-based way, without udf()s (this assumes org.apache.spark.sql.functions._, org.apache.spark.sql.Row, org.apache.spark.sql.types.{StructField, StringType} and scala.collection.mutable.WrappedArray are imported in the shell).

scala> val df = Seq((1,2,3,4),(5,4,3,1)).toDF("A","B","C","D")
df: org.apache.spark.sql.DataFrame = [A: int, B: int ... 2 more fields]

scala> val df2 = df.withColumn("arr", array(df.columns.map(col(_)):_*))
df2: org.apache.spark.sql.DataFrame = [A: int, B: int ... 3 more fields]

scala>  val rowarr = df.columns
rowarr: Array[String] = Array(A, B, C, D)

scala> val rdd1 = df2.rdd.map( x=> {val p = x.getAs[WrappedArray[Int]]("arr").toArray; val q=rowarr(p.indexWhere(_==p.min));Row.merge(x,Row(q)) })
rdd1: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = MapPartitionsRDD[83] at map at <console>:47

scala> spark.createDataFrame(rdd1,df2.schema.add(StructField("mincol",StringType))).show
+---+---+---+---+------------+------+
|  A|  B|  C|  D|         arr|mincol|
+---+---+---+---+------------+------+
|  1|  2|  3|  4|[1, 2, 3, 4]|     A|
|  5|  4|  3|  1|[5, 4, 3, 1]|     D|
+---+---+---+---+------------+------+
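For reference: Row.merge appends the computed column name to each original Row, and df2.schema.add(StructField("mincol", StringType)) extends the schema to match, which is why spark.createDataFrame can rebuild a DataFrame with the extra mincol column.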



You can do something like:

import org.apache.spark.sql.functions._

val cols = df.columns
val u1 = udf((s: Seq[Int]) => cols(s.zipWithIndex.min._2))

df.withColumn("res", u1(array("*")))

You could access the row's schema, retrieve the list of field names from it, look up each row value by name, and work it out that way.

See: https://spark.apache.org/docs/2.3.2/api/scala/index.html#org.apache.spark.sql.Row

It would look roughly like this:

dataframe.map(
    row => {
        val schema = row.schema
        val fieldNames: List[String] = schema.fieldNames.toList // extract names from schema
        // retrieve each field's value by name and retain the name of the minimum
        fieldNames.foldLeft(("", Int.MaxValue)) { case ((minName, minVal), name) =>
            val value = row.getAs[Int](name)
            if (value < minVal) (name, value) else (minName, minVal)
        }._1
    }
)

This would yield a Dataset[String]
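Note that mapping a DataFrame to String needs an implicit Encoder in scope, e.g. import spark.implicits._.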
