
How to convert Pandas Dataframe coming from RDD.mapPartitions() into Spark DataFrame?

I've got a Python function that returns a Pandas DataFrame. I'm calling this function in Spark 2.2.0 using PySpark's RDD.mapPartitions(), but I can't convert the RDD returned by mapPartitions() into a Spark DataFrame. Pandas raises this error:

ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

Simple code which illustrates the problem:

import pandas as pd

def func(data):
    pdf = pd.DataFrame(list(data), columns=("A", "B", "C"))
    pdf += 10 # Add 10 to every value. The real function is a lot more complex!
    return [pdf]

pdf = pd.DataFrame([(1.87, 0.6, 7.1), (-0.3, 0.1, 8.2), (2.8, 0.3, 6.1), (-0.2, 0.5, 5.9)], columns=("A", "B", "C"))

sdf = spark.createDataFrame(pdf)
sdf.show()
rddIn = sdf.rdd

for i in rddIn.collect():
    print(i)

result = rddIn.mapPartitions(func)

for i in result.collect():
    print(i)

resDf = spark.createDataFrame(result) # --> ValueError!
resDf.show()

The output is:

+----+---+---+
|   A|  B|  C|
+----+---+---+
|1.87|0.6|7.1|
|-0.3|0.1|8.2|
| 2.8|0.3|6.1|
|-0.2|0.5|5.9|
+----+---+---+
Row(A=1.87, B=0.6, C=7.1)
Row(A=-0.3, B=0.1, C=8.2)
Row(A=2.8, B=0.3, C=6.1)
Row(A=-0.2, B=0.5, C=5.9)
       A     B     C
0  11.87  10.6  17.1
     A     B     C
0  9.7  10.1  18.2
      A     B     C
0  12.8  10.3  16.1
     A     B     C
0  9.8  10.5  15.9

but the second-to-last line produces the ValueError mentioned above. I really want resDf.show() to look exactly the same as sdf.show(), except with 10 added to every value in the table. Ideally the result RDD should have the same structure as rddIn, the RDD going into mapPartitions().

spark.createDataFrame expects an RDD of rows (Rows, tuples, lists or dicts), not an RDD of pandas DataFrames; its schema inference treats each element as a row, which is what trips the ambiguous-truth-value ValueError. You have to convert the data to standard Python types and flatten each per-partition DataFrame into rows:

resDf = spark.createDataFrame(
    result.flatMap(lambda df: (r.tolist() for r in df.to_records()))
)

resDf.show()
# +---+------------------+----+----+                                              
# | _1|                _2|  _3|  _4|
# +---+------------------+----+----+
# |  0|11.870000000000001|10.6|17.1|
# |  0|               9.7|10.1|18.2|
# |  0|              12.8|10.3|16.1|
# |  0|               9.8|10.5|15.9|
# +---+------------------+----+----+
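Note that to_records() also emits the pandas index, which is why the extra first column (_1) above is all zeros, and the flattened tuples lose the original column names. If you want the result to look exactly like sdf.show() with 10 added everywhere, a small variation on the same idea (a sketch, not from the original answer) drops the index and reuses sdf's column names:

# Same flattening as above, but drop the pandas index and keep sdf's column names
resDf = spark.createDataFrame(
    result.flatMap(lambda df: (r.tolist() for r in df.to_records(index=False))),
    schema=list(sdf.columns)  # or pass sdf.schema to keep the exact column types
)
resDf.show()
# Expected: the same A/B/C table as sdf.show(), with 10 added to every value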

If you use Spark 2.3, this should also work:

from pyspark.sql.functions import pandas_udf, spark_partition_id, PandasUDFType

@pandas_udf(sdf.schema, functionType=PandasUDFType.GROUPED_MAP)
def func(pdf):
    # pdf holds all rows of one group (here: one input partition) as a pandas DataFrame
    pdf += 10
    return pdf

sdf.groupBy(spark_partition_id().alias("_pid")).apply(func)
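Here spark_partition_id() is only used as a grouping key, so each existing Spark partition becomes one group and func is applied to it as a whole. A minimal usage sketch (the name resDf2 is just chosen here for illustration):

resDf2 = sdf.groupBy(spark_partition_id().alias("_pid")).apply(func)
resDf2.show()
# As stated in the answer above, this should print the A/B/C table with 10 added to every value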
