
How to get all columns in Spark DataFrame recursively

I want to get all column names of a DataFrame. If the DataFrame has a flat structure (no nested StructTypes), df.columns produces the correct result. I also want to return all nested column names, e.g.

Given

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val schema = StructType(
  StructField("name", StringType) ::
  StructField("nameSecond", StringType) ::
  StructField("nameDouble", StringType) ::
  StructField("someStruct", StructType(
    StructField("insideS", StringType)::
    StructField("insideD", DoubleType)::
    Nil
  )) ::
  Nil
)
val rdd = spark.sparkContext.emptyRDD[Row]
val df = spark.createDataFrame(rdd, schema)

I want to get

Seq("name", "nameSecond", "nameDouble", "someStruct", "insideS", "insideD")

You can use this recursive function to traverse the schema:

def flattenSchema(schema: StructType): Seq[String] = {
  schema.fields.flatMap {
    // For a nested struct, emit its own name followed by all names inside it
    case StructField(name, inner: StructType, _, _) => Seq(name) ++ flattenSchema(inner)
    // For any other field, emit just its name
    case StructField(name, _, _, _) => Seq(name)
  }
}

println(flattenSchema(schema)) 
// prints: ArraySeq(name, nameSecond, nameDouble, someStruct, insideS, insideD)
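
If you start from the DataFrame rather than the schema value, the same function can be applied to its schema; a minimal usage sketch with the df defined in the question:

val allColumns: Seq[String] = flattenSchema(df.schema)
// Seq(name, nameSecond, nameDouble, someStruct, insideS, insideD)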
