I have a Spark DataFrame with many float columns after reading in a CSV file.
I want to combine all the float columns into one ArrayType(FloatType()) column.
Any ideas how to do this with PySpark (or Scala)?
If you know all the float column names, you can try this (Scala):
import org.apache.spark.sql.functions.array
val names = Seq("float_col1", "float_col2", "float_col3", /* ... */ "float_col10")
df.withColumn("combined", array(names.map(df(_)): _*))
Here is another version in Scala:
data.printSchema
root
|-- Int_Col1: integer (nullable = false)
|-- Str_Col1: string (nullable = true)
|-- Float_Col1: float (nullable = false)
|-- Float_Col2: float (nullable = false)
|-- Str_Col2: string (nullable = true)
|-- Float_Col3: float (nullable = false)
data.show()
+--------+--------+----------+----------+--------+----------+
|Int_Col1|Str_Col1|Float_Col1|Float_Col2|Str_Col2|Float_Col3|
+--------+--------+----------+----------+--------+----------+
|       1|     ABC|     10.99|     20.99|       a|      9.99|
|       2|     XYZ|  999.1343|    9858.1|       b|    488.99|
+--------+--------+----------+----------+--------+----------+
Add a new array<float> field to concatenate all float values:
val df = data.withColumn("Float_Arr_Col",array().cast("array<float>"))
Then pick out the columns of the required data type and append each float column to the array using foldLeft:
df.dtypes
  .collect { case (dn, dt) if dt.startsWith("FloatType") => dn }
  // concat keeps repeated values; array_union would drop duplicates within a row
  .foldLeft(df)((accDF, c) => accDF.withColumn("Float_Arr_Col",
    concat(col("Float_Arr_Col"), array(col(c)))))
  .show(false)
Output:
+--------+--------+----------+----------+--------+----------+--------------------------+
|Int_Col1|Str_Col1|Float_Col1|Float_Col2|Str_Col2|Float_Col3|Float_Arr_Col             |
+--------+--------+----------+----------+--------+----------+--------------------------+
|1       |ABC     |10.99     |20.99     |a       |9.99      |[10.99, 20.99, 9.99]      |
|2       |XYZ     |999.1343  |9858.1    |b       |488.99    |[999.1343, 9858.1, 488.99]|
+--------+--------+----------+----------+--------+----------+--------------------------+
Hope this helps!
Found the solution. Very straightforward, but hard to find.
from pyspark.sql.functions import array, col
float_cols = ['_c1', '_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10']
df.withColumn('combined', array([col(c) for c in float_cols]))
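If you would rather not type the column names by hand, the list can probably be derived from the schema. The following is only a sketch: the helper names float_column_names and combine_float_columns are made up for illustration, and it assumes the columns really are typed float (DataFrame.dtypes reports them as the string "float"). Columns read from a CSV with schema inference may show up as "double" instead, in which case the comparison string needs to change.

```python
def float_column_names(dtypes):
    """Return the names of columns whose Spark type string is 'float'.

    `dtypes` is a list of (name, type_string) pairs, which is the shape
    that PySpark's DataFrame.dtypes returns.
    """
    return [name for name, dtype in dtypes if dtype == "float"]


def combine_float_columns(df, out_col="combined"):
    """Add an array column holding every float column of `df`, in schema order."""
    # Imported lazily so float_column_names can be tried without Spark installed.
    from pyspark.sql.functions import array, col

    names = float_column_names(df.dtypes)
    return df.withColumn(out_col, array(*[col(c) for c in names]))


# Hypothetical dtype list mirroring the schema printed in the Scala answer.
example_dtypes = [
    ("Int_Col1", "int"),
    ("Str_Col1", "string"),
    ("Float_Col1", "float"),
    ("Float_Col2", "float"),
    ("Str_Col2", "string"),
    ("Float_Col3", "float"),
]
print(float_column_names(example_dtypes))
```

With a real DataFrame you would call combine_float_columns(df). The name selection is kept as a separate plain-Python function so it can be checked on its own before touching Spark.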