I'm looking for the most elegant and efficient way to convert a dictionary into a Spark DataFrame with PySpark, given the input and desired output below.
Input:
# Each key maps to the list of values it contains.
data = {
    "key1": ["val1", "val2", "val3"],
    "key2": ["val3", "val4", "val5"],
}
Output:
vals | keys
------------
"val1" | ["key1"]
"val2" | ["key1"]
"val3" | ["key1", "key2"]
"val4" | ["key2"]
"val5" | ["key2"]
Edit: I'd prefer to do most of the manipulation in Spark, perhaps by first converting it to:
vals | keys
------------
"val1" | "key1"
"val2" | "key1"
"val3" | "key1"
"val3" | "key2"
"val4" | "key2"
"val5" | "key2"
First construct the Spark DataFrame from the dictionary items. Then explode the `vals` column (one row per value), group by `vals`, and collect all keys that contain that value.
from pyspark.sql import functions as F
data = {"key1": ["val1", "val2", "val3"], "key2": ["val3", "val4", "val5"]}

# One row per dictionary item: `keys` holds the dict key, `vals` its list of values.
df = spark.createDataFrame(data.items(), ("keys", "vals"))

# explode     -> one row per (key, value) pair
# groupBy     -> one row per distinct value
# collect_set -> every key containing that value, without repeating a key when
#                its value list contains the same value more than once
# sort_array / orderBy -> Spark guarantees no ordering for either the collected
#                keys or the grouped rows, so sort both for a stable result
(
    df.withColumn("vals", F.explode("vals"))
    .groupBy("vals")
    .agg(F.sort_array(F.collect_set("keys")).alias("keys"))
    .orderBy("vals")
).show()
"""
+----+------------+
|vals|        keys|
+----+------------+
|val1|      [key1]|
|val2|      [key1]|
|val3|[key1, key2]|
|val4|      [key2]|
|val5|      [key2]|
+----+------------+
"""
data = dict(
    key1=["val1", "val2", "val3"],
    key2=["val3", "val4", "val5"],
)
# One row per (key, values) item: `keys` holds the dict key, `vals` its list of values.
df = spark.createDataFrame(data.items(), schema=["keys", "vals"])
df  # bare expression: a notebook renders the DataFrame here
from pyspark.sql.functions import *
from pyspark.sql.types import *
def _complex_fields(df):
    """Return ``{column name: data type}`` for the top-level complex columns of *df*.

    "Complex" means ArrayType, StructType or MapType. The dict preserves the
    order in which those columns appear in ``df.schema``.
    """
    from pyspark.sql.types import ArrayType, MapType, StructType

    complex_types = (ArrayType, StructType, MapType)
    return {
        field.name: field.dataType
        for field in df.schema.fields
        if isinstance(field.dataType, complex_types)
    }


def flatten_test(df, sep="_"):
    """Return a flattened DataFrame with no Array, Struct or Map columns.

    Complex top-level columns are processed one at a time, in schema order,
    and the schema is re-inspected after each step until none remain:

    * ``StructType`` column ``c``: every field ``f`` becomes a column named
      ``c<sep>f``, and ``c`` is dropped.
    * ``ArrayType`` column ``c``: exploded in place with ``explode_outer``,
      so each array element becomes its own row.
    * ``MapType`` column ``c``: every distinct key ``k`` found in the column
      becomes a column named ``c<sep>k`` holding ``c[k]``, and ``c`` is dropped.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        The DataFrame to flatten.
    sep : str
        Delimiter placed between the parent column name and the child
        field/key name in flattened column names. Default ``"_"``.

    Returns
    -------
    pyspark.sql.DataFrame
        The flattened DataFrame. Non-complex columns are kept.

    Notes
    -----
    Don't use ``.`` as *sep*. Struct fields are addressed as
    ``col("<name>.<field>")``, so a flattened name that itself contains a
    dot is read as a nested path. That breaks for data frames with more
    than one level of nesting, and such columns would have to be quoted
    with backticks to be selected.

    Flattening a MapType column runs a Spark job that collects every
    distinct key to the driver, which can be slow. Null or empty maps
    contribute no keys. The generated map columns are ordered by
    ``str(key)``.

    Examples
    --------
    Mixed nesting (an array of maps plus a map)::

        data_mixed = [
            {
                "state": "Florida",
                "shortname": "FL",
                "info": {"governor": "Rick Scott"},
                "counties": [
                    {"name": "Dade", "population": 12345},
                    {"name": "Broward", "population": 40000},
                    {"name": "Palm Beach", "population": 60000},
                ],
            },
            {
                "state": "Ohio",
                "shortname": "OH",
                "info": {"governor": "John Kasich"},
                "counties": [
                    {"name": "Summit", "population": 1234},
                    {"name": "Cuyahoga", "population": 1337},
                ],
            },
        ]
        data_mixed = spark.createDataFrame(data=data_mixed)
        data_mixed.printSchema()
        root
         |-- counties: array (nullable = true)
         |    |-- element: map (containsNull = true)
         |    |    |-- key: string
         |    |    |-- value: string (valueContainsNull = true)
         |-- info: map (nullable = true)
         |    |-- key: string
         |    |-- value: string (valueContainsNull = true)
         |-- shortname: string (nullable = true)
         |-- state: string (nullable = true)

        data_mixed_flat = flatten_test(data_mixed, sep=":")
        data_mixed_flat.printSchema()
        root
         |-- shortname: string (nullable = true)
         |-- state: string (nullable = true)
         |-- counties:name: string (nullable = true)
         |-- counties:population: string (nullable = true)
         |-- info:governor: string (nullable = true)

    A map column::

        data = [
            {
                "id": 1,
                "name": "Cole Volk",
                "fitness": {"height": 130, "weight": 60},
            },
            {"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}},
            {
                "id": 2,
                "name": "Faye Raker",
                "fitness": {"height": 130, "weight": 60},
            },
        ]
        df = spark.createDataFrame(data=data)
        df.printSchema()
        root
         |-- fitness: map (nullable = true)
         |    |-- key: string
         |    |-- value: long (valueContainsNull = true)
         |-- id: long (nullable = true)
         |-- name: string (nullable = true)

        df_flat = flatten_test(df, sep=":")
        df_flat.printSchema()
        root
         |-- id: long (nullable = true)
         |-- name: string (nullable = true)
         |-- fitness:height: long (nullable = true)
         |-- fitness:weight: long (nullable = true)

    A struct column::

        data_struct = [
            (("James", None, "Smith"), "OH", "M"),
            (("Anna", "Rose", ""), "NY", "F"),
            (("Julia", "", "Williams"), "OH", "F"),
            (("Maria", "Anne", "Jones"), "NY", "M"),
            (("Jen", "Mary", "Brown"), "NY", "M"),
            (("Mike", "Mary", "Williams"), "OH", "M"),
        ]
        schema = StructType([
            StructField("name", StructType([
                StructField("firstname", StringType(), True),
                StructField("middlename", StringType(), True),
                StructField("lastname", StringType(), True),
            ])),
            StructField("state", StringType(), True),
            StructField("gender", StringType(), True),
        ])
        df_struct = spark.createDataFrame(data=data_struct, schema=schema)
        df_struct.printSchema()
        root
         |-- name: struct (nullable = true)
         |    |-- firstname: string (nullable = true)
         |    |-- middlename: string (nullable = true)
         |    |-- lastname: string (nullable = true)
         |-- state: string (nullable = true)
         |-- gender: string (nullable = true)

        df_struct_flat = flatten_test(df_struct, sep=":")
        df_struct_flat.printSchema()
        root
         |-- state: string (nullable = true)
         |-- gender: string (nullable = true)
         |-- name:firstname: string (nullable = true)
         |-- name:middlename: string (nullable = true)
         |-- name:lastname: string (nullable = true)
    """
    from pyspark.sql.functions import col, explode, explode_outer, map_keys
    from pyspark.sql.types import ArrayType, MapType, StructType

    complex_fields = _complex_fields(df)
    while complex_fields:
        # Handle the first remaining complex column. The schema is recomputed
        # after every step because flattening one column can expose new
        # complex columns (e.g. exploding an array of structs yields a struct).
        col_name = next(iter(complex_fields))
        col_type = complex_fields[col_name]

        if isinstance(col_type, StructType):
            # Promote each struct field to a top-level "<struct><sep><field>" column.
            expanded = [
                col(f"{col_name}.{field.name}").alias(f"{col_name}{sep}{field.name}")
                for field in col_type.fields
            ]
            df = df.select("*", *expanded).drop(col_name)
        elif isinstance(col_type, ArrayType):
            # One row per array element, replacing the column in place.
            df = df.withColumn(col_name, explode_outer(col_name))
        elif isinstance(col_type, MapType):
            # Plain `explode` (not `explode_outer`) skips null and empty maps.
            # Otherwise they would contribute a None "key" and the column-name
            # construction below would fail.
            keys_df = df.select(explode(map_keys(col(col_name)))).distinct()
            # Sort by str() so the column order is stable whatever the key type.
            keys = sorted((row[0] for row in keys_df.collect()), key=str)
            # The f-string also copes with non-string keys (e.g. int).
            key_cols = [
                col(col_name).getItem(key).alias(f"{col_name}{sep}{key}")
                for key in keys
            ]
            kept = [name for name in df.columns if name != col_name]
            df = df.select(*kept, *key_cols)

        complex_fields = _complex_fields(df)
    return df
# Flatten the example frame; its `vals` list column is exploded to one row per value.
df_falt = flatten_test(df=df, sep="_")
df_falt  # bare expression: a notebook renders the DataFrame here
keys | vals |
---|---|
key1 | val1 |
key1 | val2 |
key1 | val3 |
key2 | val3 |
key2 | val4 |
key2 | val5 |
The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, please contact: yoyou2525@163.com.