
Transforming a Spark Dataframe column into a Dataframe with just one row (ArrayType)

I have a dataframe that contains a column with complex objects in it:

+--------+
|col1    |
+--------+
|object1 |
|object2 | 
|object3 |    
+--------+

The schema of this object is pretty complex, something that looks like:

root:struct
    field1:string
    field2:decimal(38,18)
    object1:struct
        field3:string
        object2:struct
            field4:string
            field5:decimal(38,18)

What is the best way to group everything and transform it into an array?

eg:

+-----------------------------+
|col1                         |
+-----------------------------+
| [object1, object2, object3] |    
+-----------------------------+

I tried to collect the column into an array and then create a dataframe from it:

final case class A(b: Array[Any])

val c = df.select("col1").collect().map(_(0)).toArray

df.sparkSession.createDataset(Seq(A(b = c)))

However, Spark doesn't like my Array[Any] trick:

java.lang.ClassNotFoundException: scala.Any

Any ideas?

Spark uses encoders for its datatypes, which is why Any doesn't work.
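
As a tiny illustration (the Point type here is made up, and spark is assumed to be your SparkSession): an encoder can be derived for a concrete case class, but there is no encoder for Any, so the commented-out line will not even compile.

import spark.implicits._

case class Point(x: Int, y: Int)

spark.createDataset(Seq(Point(1, 2)))          // fine: an Encoder[Point] is derived implicitly
// spark.createDataset(Seq[Any](Point(1, 2)))  // fails: there is no implicit Encoder[Any]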

If the schema of the complex object is fixed, you can define a case class with that schema and do the following:

case class C(... object1: A, object2: B ...)

val df = ???

val mappedDF = df.as[C] // this will map each complex object to case class
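
As a rough sketch of what those case classes might look like for the schema printed in the question (the names are copied from that schema, and decimal(38,18) maps to BigDecimal on the Scala side):

// Sketch only: case classes mirroring the schema shown in the question.
case class Object2(field4: String, field5: BigDecimal)
case class Object1(field3: String, object2: Object2)
case class C(field1: String, field2: BigDecimal, object1: Object1)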

Next, you can use a UDF to turn each C object into a Seq(...) at the row level. It will look something like this:

import org.apache.spark.sql.expressions.{UserDefinedFunction => UDF}
import org.apache.spark.sql.functions.{col, udf}

// Wraps the members of each C into a Seq so the result column becomes an array.
def convert: UDF =
  udf((complexObj: C) => Seq(complexObj.object1, complexObj.object2, complexObj.object3))

To use this UDF:

mappedDF.withColumn("resultColumn", convert(col("col1")))

Note: Since not much information was provided about the schema, I've used placeholder types like A and B. You will have to define all of these yourself.

What is the best way to group everything and transform it into an array?

There is no really good way to do it. Remember that Spark cannot distribute individual rows, so the result will be:

  • Processed sequentially.
  • Possibly too large to be stored in memory.

That caveat aside, you can just use collect_list:

import org.apache.spark.sql.functions.{col, collect_list}

df.select(collect_list(col("col1")))
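
As a minimal, self-contained sketch (the toy struct data below is made up just to show the shape of the result, and spark is assumed to be your SparkSession):

import org.apache.spark.sql.functions.{col, collect_list, struct}
import spark.implicits._

// Toy data: three rows, each wrapped into a struct column named col1.
val toyDF = Seq(("a", 1), ("b", 2), ("c", 3))
  .toDF("field1", "field2")
  .select(struct(col("field1"), col("field2")).as("col1"))

// Collapse every row into a single row holding an array of the structs.
val grouped = toyDF.select(collect_list(col("col1")).as("col1"))
grouped.show(false)   // one row containing an array of the three struct values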
