
How to convert columns to rows in Spark Scala or Spark SQL?

I have data like this:

+------+------+------+----------+----------+----------+----------+----------+----------+
| Col1 | Col2 | Col3 | Col1_cnt | Col2_cnt | Col3_cnt | Col1_wts | Col2_wts | Col3_wts |
+------+------+------+----------+----------+----------+----------+----------+----------+
| AAA  | VVVV | SSSS |        3 |        4 |        5 |      0.5 |      0.4 |      0.6 |
| BBB  | BBBB | TTTT |        3 |        4 |        5 |      0.5 |      0.4 |      0.6 |
| CCC  | DDDD | YYYY |        3 |        4 |        5 |      0.5 |      0.4 |      0.6 |
+------+------+------+----------+----------+----------+----------+----------+----------+

I have tried the following, but it has not gotten me anywhere:

val df = Seq(("G",Some(4),2,None),("H",None,4,Some(5))).toDF("A","X","Y", "Z")

I want the output in the form of the table below:

+-----------+---------+---------+
| Cols_name | Col_cnt | Col_wts |
+-----------+---------+---------+
| Col1      |       3 |     0.5 |
| Col2      |       4 |     0.4 |
| Col3      |       5 |     0.6 |
+-----------+---------+---------+

Here's a general approach for transposing a DataFrame:

  1. For each of the pivot columns (say c1, c2, c3), combine the column name and the associated value columns into a struct (e.g. struct(lit(c1), c1_cnt, c1_wts))
  2. Put all of these struct-typed columns into an array, which is then explode-ed into rows of struct columns
  3. Group by the pivot column name to aggregate the associated struct elements
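
Before generalizing, here is a minimal sketch of these three steps hardcoded for the sample columns (the tiny DataFrame demo and an existing SparkSession named spark, as in spark-shell, are assumed purely for illustration):

import org.apache.spark.sql.functions._
import spark.implicits._

// Tiny stand-in for the sample data, just to show the mechanics
val demo = Seq(("AAA", "VVVV", "SSSS", 3, 4, 5, 0.5, 0.4, 0.6))
  .toDF("c1", "c2", "c3", "c1_cnt", "c2_cnt", "c3_cnt", "c1_wts", "c2_wts", "c3_wts")

demo
  .select(array(                                                               // step 2: pack structs into an array
    struct(lit("c1").as("_pvt"), $"c1_cnt".as("_cnt"), $"c1_wts".as("_wts")),  // step 1: column name + its values
    struct(lit("c2").as("_pvt"), $"c2_cnt".as("_cnt"), $"c2_wts".as("_wts")),
    struct(lit("c3").as("_pvt"), $"c3_cnt".as("_cnt"), $"c3_wts".as("_wts"))
  ).as("arr"))
  .select(explode($"arr").as("s"))                                             // step 2: one row per struct
  .groupBy($"s._pvt")                                                          // step 3: group by pivot column name
  .agg(first($"s._cnt").as("Col_cnt"), first($"s._wts").as("Col_wts"))
  .show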

The following sample code has been generalized to handle an arbitrary list of columns to be transposed:

// Assumes an existing SparkSession named `spark` (as in spark-shell)
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  ("AAA", "VVVV", "SSSS", 3, 4, 5, 0.5, 0.4, 0.6),
  ("BBB", "BBBB", "TTTT", 3, 4, 5, 0.5, 0.4, 0.6),
  ("CCC", "DDDD", "YYYY", 3, 4, 5, 0.5, 0.4, 0.6)
).toDF("c1", "c2", "c3", "c1_cnt", "c2_cnt", "c3_cnt", "c1_wts", "c2_wts", "c3_wts")

val pivotCols = Seq("c1", "c2", "c3")

val valueColSfx = Seq("_cnt", "_wts")

// Step 1: for each pivot column, build a struct of the column name plus its value columns
val arrStructs = pivotCols.map { c =>
  struct((lit(c).as("_pvt") +: valueColSfx.map(s => col(c + s).as(s))): _*).as(c + "_struct")
}

// Aggregation expressions: take the first value of each struct field within a group
val valueColAgg = valueColSfx.map(s => first($"struct_col.$s").as(s + "_first"))

df.
  select(array(arrStructs: _*).as("arr_structs")).    // step 2: pack the structs into an array
  withColumn("struct_col", explode($"arr_structs")).  // step 2: explode into one row per struct
  groupBy($"struct_col._pvt").                        // step 3: group by the pivot column name
  agg(valueColAgg.head, valueColAgg.tail: _*).
  show
// +----+----------+----------+
// |_pvt|_cnt_first|_wts_first|
// +----+----------+----------+
// |  c1|         3|       0.5|
// |  c3|         5|       0.6|
// |  c2|         4|       0.4|
// +----+----------+----------+
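
The question also mentions Spark SQL. A possible equivalent (a sketch only) uses the built-in stack generator via LATERAL VIEW; the temp view name input is illustrative, and DISTINCT collapses the duplicate rows produced because every input row repeats the same cnt/wts values:

df.createOrReplaceTempView("input")

spark.sql("""
  SELECT DISTINCT s.Cols_name, s.Col_cnt, s.Col_wts
  FROM input
  LATERAL VIEW stack(3,
    'c1', c1_cnt, c1_wts,
    'c2', c2_cnt, c2_wts,
    'c3', c3_cnt, c3_wts
  ) s AS Cols_name, Col_cnt, Col_wts
""").show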

Note that the first function is used in the Scala example above, but it could be any other aggregate function (e.g. avg, max, collect_list), depending on the specific business requirement.
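
For instance, a hypothetical variation that keeps every value per pivot column instead of just the first one could swap the aggregation expressions like this (reusing arrStructs and valueColSfx from above):

// Collect all values of each struct field per pivot column instead of taking the first one
val valueColAggList = valueColSfx.map(s => collect_list($"struct_col.$s").as(s + "_list"))

df.
  select(array(arrStructs: _*).as("arr_structs")).
  withColumn("struct_col", explode($"arr_structs")).
  groupBy($"struct_col._pvt").agg(valueColAggList.head, valueColAggList.tail: _*).
  show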
