I've got a Python function that returns a Pandas DataFrame. I'm calling this function in Spark 2.2.0 using pyspark's RDD.mapPartitions(). But I can't convert the RDD returned by mapPartitions() into a Spark DataFrame. Pandas generates this error:
ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
Simple code which illustrates the problem:
import pandas as pd

def func(data):
    pdf = pd.DataFrame(list(data), columns=("A", "B", "C"))
    pdf += 10  # Add 10 to every value. The real function is a lot more complex!
    return [pdf]

pdf = pd.DataFrame([(1.87, 0.6, 7.1), (-0.3, 0.1, 8.2), (2.8, 0.3, 6.1), (-0.2, 0.5, 5.9)], columns=("A", "B", "C"))
sdf = spark.createDataFrame(pdf)
sdf.show()
rddIn = sdf.rdd
for i in rddIn.collect():
    print(i)
result = rddIn.mapPartitions(func)
for i in result.collect():
    print(i)
resDf = spark.createDataFrame(result)  # --> ValueError!
resDf.show()
The output is:
+----+---+---+
|   A|  B|  C|
+----+---+---+
|1.87|0.6|7.1|
|-0.3|0.1|8.2|
| 2.8|0.3|6.1|
|-0.2|0.5|5.9|
+----+---+---+
Row(A=1.87, B=0.6, C=7.1)
Row(A=-0.3, B=0.1, C=8.2)
Row(A=2.8, B=0.3, C=6.1)
Row(A=-0.2, B=0.5, C=5.9)
A B C
0 11.87 10.6 17.1
A B C
0 9.7 10.1 18.2
A B C
0 12.8 10.3 16.1
A B C
0 9.8 10.5 15.9
But the second to last line produces the ValueError mentioned above. I really want resDf.show() to look exactly the same as sdf.show(), except with 10 added to every value in the table. Ideally the result RDD should have the same structure as rddIn, the RDD going into mapPartitions().
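spark.createDataFrame() expects each element of the RDD to be a single row (a tuple, list, dict or Row), but here each element is a whole Pandas DataFrame. You have to convert the data to standard Python types and flatten it: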
resDf = spark.createDataFrame(
    result.flatMap(lambda df: (r.tolist() for r in df.to_records()))
)
resDf.show()
# +---+------------------+----+----+
# | _1|                _2|  _3|  _4|
# +---+------------------+----+----+
# |  0|11.870000000000001|10.6|17.1|
# |  0|               9.7|10.1|18.2|
# |  0|              12.8|10.3|16.1|
# |  0|               9.8|10.5|15.9|
# +---+------------------+----+----+
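Note that the _1 column above is the Pandas index, which to_records() includes by default, and that the column names A, B, C are lost. The following is a minimal sketch of the flattening step on its own, written without Spark so it can be run anywhere. The helper name rows_from_frames is hypothetical (it is not part of Pandas or pyspark), and the two one-row frames only imitate what each partition would return.

```python
import pandas as pd

def rows_from_frames(frames, index=False):
    """Flatten an iterable of Pandas DataFrames into plain Python tuples.

    rows_from_frames is a hypothetical helper name used for illustration.
    With index=False the Pandas index is left out of every row.
    """
    for df in frames:
        for rec in df.to_records(index=index):
            # tolist() converts the NumPy record into a tuple of Python scalars
            yield rec.tolist()

pdf = pd.DataFrame([(1.87, 0.6, 7.1), (-0.3, 0.1, 8.2)], columns=("A", "B", "C"))

# Imitate two partitions, each holding a one-row DataFrame with 10 added
parts = [pdf.iloc[[0]] + 10, pdf.iloc[[1]] + 10]

rows = list(rows_from_frames(parts))
rows_with_index = list(rows_from_frames(parts, index=True))
print(rows)
```

Because the helper takes an iterator of DataFrames, it should be usable as result.mapPartitions(rows_from_frames), and passing sdf.schema as the second argument of spark.createDataFrame() should then restore the A, B, C column names. That Spark part is untested here and is an assumption based on the code in the question.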
If you use Spark 2.3, this should also work:
from pyspark.sql.functions import pandas_udf, spark_partition_id
from pyspark.sql.functions import PandasUDFType
@pandas_udf(sdf.schema, functionType=PandasUDFType.GROUPED_MAP)
def func(pdf):
    pdf += 10
    return pdf
sdf.groupBy(spark_partition_id().alias("_pid")).apply(func)