繁体   English   中英

PySpark 将“map”类型的列转换为 dataframe 中的多个列

[英]PySpark converting a column of type 'map' to multiple columns in a dataframe

输入

我有一个map类型的列Parameters ,其形式为:

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
d = [{'Parameters': {'foo': '1', 'bar': '2', 'baz': 'aaa'}}]
df = sqlContext.createDataFrame(d)

df.collect()
# [Row(Parameters={'foo': '1', 'bar': '2', 'baz': 'aaa'})]

df.printSchema()
# root
#  |-- Parameters: map (nullable = true)
#  |    |-- key: string
#  |    |-- value: string (valueContainsNull = true)

Output

我想在 PySpark 中重塑它,以便所有键( foobar等)都成为列,即:

[Row(foo='1', bar='2', baz='aaa')]

使用withColumn有效:

(df
 .withColumn('foo', df.Parameters['foo'])
 .withColumn('bar', df.Parameters['bar'])
 .withColumn('baz', df.Parameters['baz'])
 .drop('Parameters')
).collect()

我需要一个没有明确提及列名的解决方案,因为我有几十个。

由于MapType键不是架构的一部分,因此您必须首先收集它们,例如:

from pyspark.sql.functions import explode

keys = (df
    .select(explode("Parameters"))
    .select("key")
    .distinct()
    .rdd.flatMap(lambda x: x)
    .collect())

当您拥有这些后,剩下的就是简单的选择:

from pyspark.sql.functions import col

exprs = [col("Parameters").getItem(k).alias(k) for k in keys]
df.select(*exprs)

高性能解决方案

问题约束之一是动态确定列名,这很好,但请注意,这可能非常慢。 下面介绍了如何避免键入和编写可快速执行的代码。

cols = list(map(
    lambda f: F.col("Parameters").getItem(f).alias(str(f)),
    ["foo", "bar", "baz"]))
df.select(cols).show()
+---+---+---+
|foo|bar|baz|
+---+---+---+
|  1|  2|aaa|
+---+---+---+

请注意,这将运行单个选择操作。 不要多次运行withColumn ,因为那会比较慢。

只有当您知道所有地图键时,才有可能快速解决。 如果您不知道地图键的所有唯一值,您将需要恢复到较慢的解决方案。

较慢的解决方案

接受的答案很好。 我的解决方案性能更高一些,因为它不调用.rddflatMap()

import pyspark.sql.functions as F

d = [{'Parameters': {'foo': '1', 'bar': '2', 'baz': 'aaa'}}]
df = spark.createDataFrame(d)

keys_df = df.select(F.explode(F.map_keys(F.col("Parameters")))).distinct()
keys = list(map(lambda row: row[0], keys_df.collect()))
key_cols = list(map(lambda f: F.col("Parameters").getItem(f).alias(str(f)), keys))
df.select(key_cols).show()
+---+---+---+
|bar|foo|baz|
+---+---+---+
|  2|  1|aaa|
+---+---+---+

将结果收集到驱动程序节点可能是性能瓶颈。 最好将此代码list(map(lambda row: row[0], keys_df.collect()))作为单独的命令执行,以确保它不会运行得太慢。

性能方面,而不是硬编码列名,使用这个:

from pyspark.sql import functions as F

df = df.withColumn("_c", F.to_json("Parameters"))
json_schema = spark.read.json(df.rdd.map(lambda r: r._c)).schema
df = df.withColumn("_c", F.from_json("_c", json_schema))
df = df.select("_c.*")

df.show()
# +----+----+---+
# | bar| baz|foo|
# +----+----+---+
# |   2| aaa|  1|
# |null|null|  1|
# +----+----+---+

它既不使用distinct也不使用collect 它曾经调用rdd ,以便提取的模式具有合适的格式以在from_json中使用。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM