
How to convert an array of maps to a single map column in Scala/Spark?

I have a DataFrame with this schema:

 |-- CLICK_THROUGH_RATE_MAP: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: double (valueContainsNull = true)

+---------------------------------+
|CLICK_THROUGH_RATE_MAP           |
+---------------------------------+
|[[web -> 2.47]]                  |
|[[mobile -> 2.36], [web -> 3.78]]|
+---------------------------------+

How can I convert the array-of-maps column into a single map column, so that the result and schema look like this:

+---------------------------------+
|CLICK_THROUGH_RATE_MAP           |
+---------------------------------+
|[web -> 2.47]                    |
|[mobile -> 2.36, web -> 3.78]    |
+---------------------------------+

 |-- CLICK_THROUGH_RATE_MAP: map (nullable = false)
 |    |-- key: string
 |    |-- value: double (valueContainsNull = true)
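The desired per-row result is every map in the array merged into one map. A minimal plain-Scala sketch of that merge on the two sample rows, run outside Spark, may help pin down the semantics. The object name MergeMaps is hypothetical, and it assumes that when a key repeats, the value from the later map should win (which is what flatten.toMap does). Null maps are skipped because the array is declared with containsNull = true.

```scala
// Plain-Scala sketch of merging a row's array of maps into one map (runs without Spark).
// MergeMaps is a hypothetical name; duplicate keys resolve to the later map's value.
object MergeMaps {
  // The array column is declared containsNull = true, so null maps are skipped;
  // a null array stays null, mirroring a null input row.
  def merge(values: Seq[Map[String, Double]]): Map[String, Double] =
    if (values == null) null
    else values.filter(_ != null).flatten.toMap

  // The two sample rows from the question.
  val row1: Seq[Map[String, Double]] = Seq(Map("web" -> 2.47))
  val row2: Seq[Map[String, Double]] = Seq(Map("mobile" -> 2.36), Map("web" -> 3.78))

  def main(args: Array[String]): Unit = {
    println(merge(row1))
    println(merge(row2))
  }
}
```

If a different rule is wanted for duplicate keys (for example keeping the first value or taking the maximum), only the body of merge needs to change.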

I tried a UDF:

import org.apache.spark.sql.functions.{col, udf}

// Merge all maps in the array into one map; on duplicate keys the later map wins
val joinMap = udf { values: Seq[Map[String, Double]] => values.flatten.toMap }

val df2 = df
  .select("CLICK_THROUGH_RATE_MAP")
  .withColumn("map_output", joinMap(col("CLICK_THROUGH_RATE_MAP")))
df2.show(false)

But this fails with:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 362.0 failed 4 times, most recent failure: Lost task 0.3 in stage 362.0 (TID 278364, ip-10-0-1-112.ec2.internal, executor 1): java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
    at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2287)
    at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1417)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2347)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2341)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:464)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:422)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
    at sun.reflect.GeneratedMethodAccessor179.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2232)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2341)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2341)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:464)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:422)
    at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
    at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:83)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
  at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2041)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2029)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2028)
  at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
  at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
  at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2028)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
  at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
  at scala.Option.foreach(Option.scala:257)
  at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:966)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2262)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2211)
  at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2200)
  at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
  at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:777)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
  at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
  at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:401)
  at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
  at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3389)
  at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2550)
  at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2550)
  at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3370)
  at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
  at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
  at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
  at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3369)
  at org.apache.spark.sql.Dataset.head(Dataset.scala:2550)
  at org.apache.spark.sql.Dataset.take(Dataset.scala:2764)
  at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
  at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:753)
  at org.apache.spark.sql.Dataset.show(Dataset.scala:730)
  ... 53 elided
Caused by: java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
  at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2287)
  at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1417)
  at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2347)
  at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
  at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
  at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
  at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2341)
  at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
  at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
  at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
  at java.io.ObjectInputStream.readObject(ObjectInputStream.java:464)
  at java.io.ObjectInputStream.readObject(ObjectInputStream.java:422)
  at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:490)
  at sun.reflect.GeneratedMethodAccessor179.invoke(Unknown Source)
  at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.lang.reflect.Method.invoke(Method.java:498)
  at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1170)
  at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2232)
  at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
  at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
  at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2341)
  at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
  at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
  at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
  at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2341)
  at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2265)
  at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2123)
  at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1624)
  at java.io.ObjectInputStream.readObject(ObjectInputStream.java:464)
  at java.io.ObjectInputStream.readObject(ObjectInputStream.java:422)
  at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
  at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:83)
  at org.apache.spark.scheduler.Task.run(Task.scala:123)
  at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
  ... 3 more

You will need to explode the array column:

import org.apache.spark.sql.functions.{col, explode}
df.withColumn("map_output", explode(col("CLICK_THROUGH_RATE_MAP")))

Note that explode generates one output row per array element, so a row whose array holds more than one map becomes several rows, each containing a single map.
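To make the row multiplication concrete, here is a plain-Scala analogy (not Spark code) in which explode is modeled as a flatMap over rows, applied to the sample data. The object name ExplodeModel is hypothetical. It shows that the second input row turns into two output rows, and that no output row holds a merged map.

```scala
// Plain-Scala model of what explode does to the sample rows (an analogy, not Spark code).
// ExplodeModel is a hypothetical name; explode behaves like a flatMap over rows.
object ExplodeModel {
  // Each input row holds an array (Seq) of maps, as in the question.
  val rows: Seq[Seq[Map[String, Double]]] = Seq(
    Seq(Map("web" -> 2.47)),
    Seq(Map("mobile" -> 2.36), Map("web" -> 3.78))
  )

  // One output row per array element: 2 input rows become 3 output rows,
  // and each output row still holds a single map rather than a merged one.
  // A null array yields no rows, as explode does (explode_outer would keep the row).
  val exploded: Seq[Map[String, Double]] =
    rows.flatMap(row => Option(row).getOrElse(Seq.empty[Map[String, Double]]))

  def main(args: Array[String]): Unit =
    exploded.foreach(println)
}
```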

If the array is guaranteed to always hold exactly one map, either of the following also works:

df.withColumn("map_output", col("CLICK_THROUGH_RATE_MAP")(0))

// element_at uses 1-based indexing (available since Spark 2.4.0)
df.withColumn("map_output", element_at(col("CLICK_THROUGH_RATE_MAP"), 1))

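This can also be done with reduce on a typed Dataset (with import spark.implicits._ in scope): df.select("CLICK_THROUGH_RATE_MAP").as[Seq[Map[String, Double]]].map(_.reduce((a, b) => a ++ b))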
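The reduce step itself can be sanity-checked in plain Scala outside Spark. In this sketch (ReduceMerge and mergeSafe are hypothetical names), Map ++ Map keeps the right-hand value for a duplicate key. One caveat worth knowing: reduce throws on an empty collection, so a row holding an empty array would fail. A foldLeft starting from an empty map is one way to avoid that, shown here as an assumption about the desired behavior for empty arrays.

```scala
// Plain-Scala sketch of the reduce-based merge from this answer (runs without Spark).
// ReduceMerge and mergeSafe are hypothetical names.
object ReduceMerge {
  val row2: Seq[Map[String, Double]] = Seq(Map("mobile" -> 2.36), Map("web" -> 3.78))

  // Map ++ Map keeps the right-hand value when a key appears in both maps.
  val viaReduce: Map[String, Double] = row2.reduce((a, b) => a ++ b)

  // reduce throws UnsupportedOperationException on an empty Seq (an empty array);
  // foldLeft from an empty map returns an empty map instead.
  def mergeSafe(values: Seq[Map[String, Double]]): Map[String, Double] =
    values.foldLeft(Map.empty[String, Double])(_ ++ _)

  def main(args: Array[String]): Unit = {
    println(viaReduce)
    println(mergeSafe(Seq.empty))
  }
}
```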
