
PySpark: How to write a Spark dataframe having a column with type SparseVector into CSV file?

I have a Spark dataframe that has one column of type spark.mllib.linalg.SparseVector:

1) How can I write it into a CSV file?

2) How can I print all the vectors?

To write the dataframe to a CSV file you can use the standard df.write.csv(output_path).

However, if you just do that, you are likely to get a java.lang.UnsupportedOperationException: CSV data source does not support struct<type:tinyint,size:int,indices:array<int>,values:array<double>> data type error for the SparseVector column.

There are two ways to serialize the SparseVector as a string and avoid that error: the dense format or the sparse format.
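For concreteness, the snippets below assume a dataframe df with a SparseVector column; here is a minimal, hypothetical setup (the data, column names, and output_path are made up for illustration, not part of the original question):

from pyspark.sql import SparkSession
from pyspark.mllib.linalg import SparseVector

spark = SparkSession.builder.getOrCreate()

# Hypothetical example: two rows with a SparseVector column called 'column_name'
df = spark.createDataFrame(
    [(1, SparseVector(4, [0, 2], [1.0, 5.0])),
     (2, SparseVector(4, [1], [3.0]))],
    ['id', 'column_name'],
)
output_path = '/tmp/sparse_vector_csv'  # assumed path for the write examples below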

If you want to print in the dense format, you can define a UDF like this:

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark.sql.functions import col

# Join all values of the vector into a comma-separated string (dense representation)
dense_format_udf = udf(lambda x: ','.join([str(elem) for elem in x]), StringType())

df = df.withColumn('column_name', dense_format_udf(col('column_name')))

df.write.option("delimiter", "\t").csv(output_path)

The column outputs to something like this in the dense format: 1.0,0.0,5.0,0.0
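An equivalent variant, if you prefer to densify the vector explicitly rather than rely on element-wise iteration, is to call toArray() inside the UDF (a small sketch, not part of the original answer):

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Same dense output, but convert the SparseVector to a dense array first
dense_format_udf = udf(lambda x: ','.join(str(elem) for elem in x.toArray()), StringType())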

If you want to print in the sparse format, you can utilize the out-of-the-box __str__ function of the SparseVector class, or be creative and define your own output format. Here I am going to use the __str__ function.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark.sql.functions import col

sparse_format_udf = udf(lambda x: str(x), StringType())

df = df.withColumn('column_name', sparse_format_udf(col('column_name')))

df.write.option("delimiter", "\t").csv(output_path)

The column prints to something like this in the sparse format: (4,[0,2],[1.0,5.0])

Note that I tried this approach before: df = df.withColumn("column_name", col("column_name").cast("string")), but the column just prints to something like [0,5,org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@6988050,org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@ec4ae6ab], which is not desirable.
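An alternative that avoids a Python UDF altogether, assuming Spark 3.0+ where pyspark.ml.functions.vector_to_array is available (my addition, not part of the original answer), is to convert the vector into an array column and then join it into a string. The join step is still needed because the CSV writer does not support array columns either.

from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import array_join, col

# vector -> array<double> -> array<string> -> comma-separated string
df = df.withColumn(
    'column_name',
    array_join(vector_to_array(col('column_name')).cast('array<string>'), ','),
)
df.write.option("delimiter", "\t").csv(output_path)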

  1. To write the dataframe to CSV, you can use the spark-csv package: https://github.com/databricks/spark-csv

  2. To pull the vector column out, you can use

    df2 = df1.map(lambda row: row.yourVectorCol)

    OR df1.map(lambda row: row[1])

    where you either have a named column or just refer to the column by its position in the row.

    Then, to print it, you can df2.collect() (a Spark 2.0+ equivalent is sketched below).

Without more information this may or may not be helpful enough to you; please elaborate a bit.
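For the second part of the question (printing all the vectors), note that DataFrame.map was removed in Spark 2.0, so the snippet above needs .rdd in newer versions. A small sketch, assuming the vector column is named column_name:

# Collect just the vector column and print each SparseVector
for row in df.select('column_name').collect():
    print(row['column_name'])

# Equivalent via the underlying RDD (Spark 2.0+ needs .rdd before .map)
for v in df.rdd.map(lambda row: row['column_name']).collect():
    print(v)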
