
Create a DataFrame with a SparseVector column in PySpark

Let's say I have a Spark DataFrame that looks like this:

Row(Y='a', X1=3.2, X2=4.5)

What I'd want is:

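Row(Y='a', features=SparseVector(2, {0: 3.2, 1: 4.5}))

where index 0 holds X1 and index 1 holds X2 (SparseVector keys are integer positions, not column names).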

Perhaps this helps.

It is written in Scala, but it can be ported to PySpark with minimal changes.

Use VectorAssembler to build a single vector column from the input columns:

// Assumes an existing SparkSession named `spark` (as in spark-shell).
val df = spark.sql("select 'a' as Y, 3.2 as X1, 4.5 as X2")
df.show(false)
df.printSchema()

/**
  * +---+---+---+
  * |Y  |X1 |X2 |
  * +---+---+---+
  * |a  |3.2|4.5|
  * +---+---+---+
  *
  * root
  * |-- Y: string (nullable = false)
  * |-- X1: decimal(2,1) (nullable = false)
  * |-- X2: decimal(2,1) (nullable = false)
  */
import org.apache.spark.ml.feature.VectorAssembler

// Combine the numeric columns X1 and X2 into one vector column named "features".
val features = new VectorAssembler()
  .setInputCols(Array("X1", "X2"))
  .setOutputCol("features")
  .transform(df)
features.show(false)
features.printSchema()

/**
  * +---+---+---+---------+
  * |Y  |X1 |X2 |features |
  * +---+---+---+---------+
  * |a  |3.2|4.5|[3.2,4.5]|
  * +---+---+---+---------+
  *
  * root
  * |-- Y: string (nullable = false)
  * |-- X1: decimal(2,1) (nullable = false)
  * |-- X2: decimal(2,1) (nullable = false)
  * |-- features: vector (nullable = true)
  *
  * [3.2,4.5] is a dense vector; a sparse one would print as (2,[0,1],[3.2,4.5]).
  */
