Flatten a DataFrame in Scala with different DataTypes inside

As you may know, a DataFrame can contain fields of complex types, such as structures (StructType) or arrays (ArrayType). You may need, as in my case, to map all the DataFrame data to a Hive table with simple-type fields (String, Integer...). I struggled with this issue for a long time and finally found a solution I want to share. I'm sure it could be improved, so feel free to reply with your own suggestions.

It's based on this thread, but it also works for ArrayType elements, not only StructType ones. It is a tail-recursive function that receives a DataFrame and returns it flattened.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, StructType}

def flattenDf(df: DataFrame): DataFrame = {
  var end = false
  var i = 0
  val fields = df.schema.fields
  val fieldNames = fields.map(f => f.name)
  val fieldsNumber = fields.length

  while (!end) {
    val field = fields(i)
    val fieldName = field.name

    field.dataType match {
      case st: StructType =>
        val childFieldNames = st.fieldNames.map(n => fieldName + "." + n)
        val newFieldNames = fieldNames.filter(_ != fieldName) ++ childFieldNames
        val newDf = df.selectExpr(newFieldNames: _*)
        return flattenDf(newDf)
      case at: ArrayType =>
        // Assumes the array elements are structs: "a.*" cannot expand an array of simple values.
        val fieldNamesExcludingArray = fieldNames.filter(_ != fieldName)
        val fieldNamesAndExplode = fieldNamesExcludingArray ++ Array(s"explode($fieldName) as a")
        val fieldNamesToSelect = fieldNamesExcludingArray ++ Array("a.*")
        val explodedDf = df.selectExpr(fieldNamesAndExplode: _*)
        val explodedAndSelectedDf = explodedDf.selectExpr(fieldNamesToSelect: _*)
        return flattenDf(explodedAndSelectedDf)
      case _ => () // simple type: nothing to flatten
    }

    i += 1
    end = i >= fieldsNumber
  }
  df
}
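
One caveat about the ArrayType branch above: `a.*` can only expand struct elements, so an array of plain values (such as `array<int>`) makes that select fail. A minimal sketch of a helper that handles both cases might look like the following. The name `explodeArrayColumn` and the temporary column name `a` are illustrative assumptions, not part of the original code, and column names containing dots would need extra quoting.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, explode}
import org.apache.spark.sql.types.{ArrayType, StructType}

// Hypothetical helper: explodes one array column. Struct elements are expanded
// into top-level columns; simple elements stay in a column with the same name.
def explodeArrayColumn(df: DataFrame, fieldName: String): DataFrame = {
  // Every column except the array being exploded.
  val others = df.columns.filter(_ != fieldName).map(c => col(c))

  df.schema(fieldName).dataType match {
    case ArrayType(_: StructType, _) =>
      // Array of structs: explode into a temporary column "a", then expand it.
      df.select(others :+ explode(col(fieldName)).as("a"): _*)
        .select(others :+ col("a.*"): _*)
    case _: ArrayType =>
      // Array of simple values: one row per element, same column name.
      df.select(others :+ explode(col(fieldName)).as(fieldName): _*)
    case _ =>
      // Not an array: leave the DataFrame unchanged.
      df
  }
}
```

Note that `explode` drops rows whose array is null or empty; in Spark versions that provide it, `explode_outer` keeps those rows with a null value instead.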

val df = Seq(("1", (2, (3, 4)), Seq(1, 2))).toDF()

df.printSchema

root
 |-- _1: string (nullable = true)
 |-- _2: struct (nullable = true)
 |    |-- _1: integer (nullable = false)
 |    |-- _2: struct (nullable = true)
 |    |    |-- _1: integer (nullable = false)
 |    |    |-- _2: integer (nullable = false)
 |-- _3: array (nullable = true)
 |    |-- element: integer (containsNull = false)


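import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, explode}
import org.apache.spark.sql.types.{ArrayType, StructType}

// Builds one Column per leaf field; fieldName is the dotted path of the parent struct.
def flattenSchema(schema: StructType, fieldName: String = null): Array[Column] = {
  schema.fields.flatMap(f => {
    val cols = if (fieldName == null) f.name else (fieldName + "." + f.name)
    f.dataType match {
      case structType: StructType => flattenSchema(structType, cols) // recurse into nested struct
      case arrayType: ArrayType => Array(explode(col(cols)))         // one row per array element
      case _ => Array(col(cols))
    }
  })
}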

df.select(flattenSchema(df.schema): _*).printSchema

root
 |-- _1: string (nullable = true)
 |-- _1: integer (nullable = true)
 |-- _1: integer (nullable = true)
 |-- _2: integer (nullable = true)
 |-- col: integer (nullable = false)
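
The flattened schema above ends up with three columns called `_1`, which makes later references by name ambiguous. One possible way around this, sketched below under stated assumptions, is to alias every leaf column with its full path. The function name `flattenSchemaAliased` and the choice of replacing `.` with `_` are illustrative, not part of the original answer, and the resulting names are not guaranteed to be unique for every schema.

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, explode}
import org.apache.spark.sql.types.{ArrayType, StructType}

// Hypothetical variant of flattenSchema: same traversal, but each leaf column
// is renamed after its dotted path with "." replaced by "_".
def flattenSchemaAliased(schema: StructType, prefix: Option[String] = None): Array[Column] =
  schema.fields.flatMap { f =>
    // Dotted path used to select the field, e.g. "_2._2._1".
    val path = prefix.fold(f.name)(p => s"$p.${f.name}")
    // Flat output name derived from the path, e.g. "_2__2__1".
    val flatName = path.replace(".", "_")
    f.dataType match {
      case st: StructType => flattenSchemaAliased(st, Some(path))
      case _: ArrayType   => Array(explode(col(path)).alias(flatName))
      case _              => Array(col(path).alias(flatName))
    }
  }
```

With the example DataFrame, `df.select(flattenSchemaAliased(df.schema): _*)` should yield the columns `_1`, `_2__1`, `_2__2__1`, `_2__2__2` and `_3`. As with the version above, Spark allows only one generator such as `explode` per select, so a schema with several array columns would need them exploded in separate steps.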
