
How to map struct in DataFrame to case class?

At some point in my application, I have a DataFrame with a Struct field created from a case class. Now I want to cast/map it back to the case class type:

import spark.implicits._
case class Location(lat: Double, lon: Double)

scala> Seq((10, Location(35, 25)), (20, Location(45, 35))).toDF
res25: org.apache.spark.sql.DataFrame = [_1: int, _2: struct<lat: double, lon: double>]

scala> res25.printSchema
root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- lat: double (nullable = false)
 |    |-- lon: double (nullable = false)

And the basic approach:

res25.map(r => {
   Location(r.getStruct(1).getDouble(0), r.getStruct(1).getDouble(1))
}).show(1)

Looks really dirty. Is there any simpler way?

In Spark 1.6+, if you want to retain the type information, use a Dataset (DS), not a DataFrame (DF).

import spark.implicits._
case class Location(lat: Double, lon: Double)

scala> Seq((10, Location(35, 25)), (20, Location(45, 35))).toDS
res25: org.apache.spark.sql.Dataset[(Int, Location)] = [_1: int, _2: struct<lat: double, lon: double>]

scala> res25.printSchema
root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- lat: double (nullable = false)
 |    |-- lon: double (nullable = false)

It will give you Dataset[(Int, Location)]. Now, if you want to get back to its case class origin, simply do this:

scala> res25.map(r => r._2).show(1)
+----+----+
| lat| lon|
+----+----+
|35.0|25.0|
+----+----+
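
Alternatively (in Spark 2.0+), a typed column lets you project the struct straight into the case class without writing the mapping by hand. A minimal sketch, assuming the same res25 Dataset as above:

scala> res25.select($"_2".as[Location]).show(1)

This returns a Dataset[Location], so the mapping from struct fields to constructor parameters is handled by the encoder.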

But if you want to stick to the DataFrame API, due to its dynamically typed nature, then you have to code it like this:

scala> res25.select("_2.*").map(r => Location(r.getDouble(0), r.getDouble(1))).show(1)
+----+----+
| lat| lon|
+----+----+
|35.0|25.0|
+----+----+
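
If you'd rather not rely on field positions, Row also has getAs, so the same mapping can be written by field name. A minimal sketch, assuming the same res25 as above:

scala> res25.select("_2.*").map(r => Location(r.getAs[Double]("lat"), r.getAs[Double]("lon"))).show(1)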

You could also use the extractor pattern on Row, which would give you similar results using more idiomatic Scala:

scala> res25.map { row =>
  (row: @unchecked) match {
    case Row(a: Int, Row(b: Double, c: Double)) => (a, Location(b, c))
  }
}
res26: org.apache.spark.sql.Dataset[(Int, Location)] = [_1: int, _2: struct<lat: double, lon: double>]
scala> res26.collect()
res27: Array[(Int, Location)] = Array((10,Location(35.0,25.0)), (20,Location(45.0,35.0)))

I think the other answers nailed it, but perhaps they could use some rewording.

In short, it's not possible to use case classes in DataFrames since DataFrames don't care about case classes and use RowEncoder to map internal SQL types to a Row.

As the other answers said, you have to turn the Row-based DataFrame into a Dataset using the as operator.

val df = Seq((10, Location(35, 25)), (20, Location(45, 35))).toDF
scala> val ds = df.as[(Int, Location)]
ds: org.apache.spark.sql.Dataset[(Int, Location)] = [_1: int, _2: struct<lat: double, lon: double>]

scala> ds.show
+---+-----------+
| _1|         _2|
+---+-----------+
| 10|[35.0,25.0]|
| 20|[45.0,35.0]|
+---+-----------+

scala> ds.printSchema
root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- lat: double (nullable = false)
 |    |-- lon: double (nullable = false)

scala> ds.map[TAB pressed twice]

def map[U](func: org.apache.spark.api.java.function.MapFunction[(Int, Location),U],encoder: org.apache.spark.sql.Encoder[U]): org.apache.spark.sql.Dataset[U]
def map[U](func: ((Int, Location)) => U)(implicit evidence$6: org.apache.spark.sql.Encoder[U]): org.apache.spark.sql.Dataset[U]
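
With the typed map overload above, getting back to the case class is just a typed function. A minimal sketch continuing with the ds defined earlier:

scala> ds.map { case (_, location) => location }

which yields a Dataset[Location].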
