
How to do vector operations on pyspark dataframe columns?

I have a DataFrame like this:

id | vector1 | id2 | vector2

where id and id2 are integers and the vectors are of type SparseVector.

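For each row, I want to add a column with the cosine similarity, computed as vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2))), but I can't figure out how to put this into a new column. I've tried writing a UDF, but can't seem to get it to work.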
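
Written as a formula, this is the usual cosine similarity, with each norm expressed as the square root of a vector's dot product with itself:

$$
\cos\theta = \frac{\mathbf{v}_1 \cdot \mathbf{v}_2}{\sqrt{\mathbf{v}_1 \cdot \mathbf{v}_1}\,\sqrt{\mathbf{v}_2 \cdot \mathbf{v}_2}} = \frac{\mathbf{v}_1 \cdot \mathbf{v}_2}{\lVert \mathbf{v}_1 \rVert_2 \, \lVert \mathbf{v}_2 \rVert_2}
$$

As a quick check with the first row of the sample data used in the answer, v1 = v2 = (0.0, 10.0, 0.5): v1 · v2 = 0 + 100 + 0.25 = 100.25 and both norms are √100.25, so the exact result is 1.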

Solution using Scala

There is a utility object org.apache.spark.ml.linalg.BLAS inside the Spark repo which uses com.github.fommil.netlib.BLAS to compute dot products. However, that object is package-private (visible only to Spark's own code), so to use it here we need to copy the utility into the current project, as follows:

package utils

import com.github.fommil.netlib.{F2jBLAS, BLAS => NetlibBLAS}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

/**
  * Copy of org.apache.spark.ml.linalg.BLAS, which is package-private in Spark,
  * so that its BLAS routines for ML vectors can be used from this project.
  */
object BLAS extends Serializable {

  @transient private var _f2jBLAS: NetlibBLAS = _

  // For level-1 routines, we use Java implementation.
  private def f2jBLAS: NetlibBLAS = {
    if (_f2jBLAS == null) {
      _f2jBLAS = new F2jBLAS
    }
    _f2jBLAS
  }

  /**
    * dot(x, y)
    */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
        " x.size = " + x.size + ", y.size = " + y.size)
    (x, y) match {
      case (dx: DenseVector, dy: DenseVector) =>
        dot(dx, dy)
      case (sx: SparseVector, dy: DenseVector) =>
        dot(sx, dy)
      case (dx: DenseVector, sy: SparseVector) =>
        dot(sy, dx)
      case (sx: SparseVector, sy: SparseVector) =>
        dot(sx, sy)
      case _ =>
        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
    }
  }

  /**
    * dot(x, y)
    */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val n = x.size
    f2jBLAS.ddot(n, x.values, 1, y.values, 1)
  }

  /**
    * dot(x, y)
    */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val nnz = xIndices.length

    var sum = 0.0
    var k = 0
    while (k < nnz) {
      sum += xValues(k) * yValues(xIndices(k))
      k += 1
    }
    sum
  }

  /**
    * dot(x, y)
    */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val yIndices = y.indices
    val nnzx = xIndices.length
    val nnzy = yIndices.length

    var kx = 0
    var ky = 0
    var sum = 0.0
    // y catching x
    while (kx < nnzx && ky < nnzy) {
      val ix = xIndices(kx)
      while (ky < nnzy && yIndices(ky) < ix) {
        ky += 1
      }
      if (ky < nnzy && yIndices(ky) == ix) {
        sum += xValues(kx) * yValues(ky)
        ky += 1
      }
      kx += 1
    }
    sum
  }
}
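
The sparse-sparse case above is essentially a merge of two sorted index arrays. A minimal standalone sketch of the same two-pointer idea is shown below. It assumes, as Spark's SparseVector does, that each vector's indices are strictly increasing and aligned with its values, and it uses plain arrays instead of Spark types; the object and method names are hypothetical.

```scala
object SparseDotSketch {
  /** Dot product of two sparse vectors, each given as (sorted indices, values).
    * Assumes every index array is strictly increasing and matches its values array in length. */
  def sparseDot(xIdx: Array[Int], xVal: Array[Double],
                yIdx: Array[Int], yVal: Array[Double]): Double = {
    require(xIdx.length == xVal.length && yIdx.length == yVal.length,
      "indices and values must have the same length")
    var kx = 0
    var ky = 0
    var sum = 0.0
    while (kx < xIdx.length && ky < yIdx.length) {
      if (xIdx(kx) == yIdx(ky)) {
        // Both vectors are non-zero at this index: accumulate the product.
        sum += xVal(kx) * yVal(ky)
        kx += 1
        ky += 1
      } else if (xIdx(kx) < yIdx(ky)) {
        kx += 1 // x lags behind y: advance x
      } else {
        ky += 1 // y lags behind x: advance y
      }
    }
    sum
  }
}
```

For example, x with indices (0, 2, 5) and values (1.0, 2.0, 3.0) against y with indices (2, 3, 5) and values (4.0, 5.0, 6.0) only overlaps at indices 2 and 5, giving 2·4 + 3·6 = 26. This version advances whichever pointer lags; the copied Spark code instead fast-forwards y in an inner loop, which visits the same index pairs.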

Using the above utility to compute the cosine similarity:

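    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.functions.udf
    import spark.implicits._ // assumes a SparkSession named `spark` is in scope

    val df = Seq(
      (0, Vectors.dense(0.0, 10.0, 0.5), 1, Vectors.dense(0.0, 10.0, 0.5)),
      (1, Vectors.dense(0.0, 10.0, 0.2), 2, Vectors.dense(0.0, 10.0, 0.2))
    ).toDF("id", "vector1", "id2", "vector2")
    df.show(false)
    df.printSchema()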
    /**
      * +---+--------------+---+--------------+
      * |id |vector1       |id2|vector2       |
      * +---+--------------+---+--------------+
      * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|
      * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|
      * +---+--------------+---+--------------+
      *
      * root
      * |-- id: integer (nullable = false)
      * |-- vector1: vector (nullable = true)
      * |-- id2: integer (nullable = false)
      * |-- vector2: vector (nullable = true)
      */

    // cosine = vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
    val cosine_similarity = udf((vector1: Vector, vector2: Vector) =>
      utils.BLAS.dot(vector1, vector2) /
        (Math.sqrt(utils.BLAS.dot(vector1, vector1)) * Math.sqrt(utils.BLAS.dot(vector2, vector2)))
    )
    df.withColumn("cosine", cosine_similarity($"vector1", $"vector2"))
      .show(false)

    /**
      * +---+--------------+---+--------------+------------------+
      * |id |vector1       |id2|vector2       |cosine            |
      * +---+--------------+---+--------------+------------------+
      * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|0.9999999999999999|
      * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|1.0000000000000002|
      * +---+--------------+---+--------------+------------------+
      */
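
One caveat with the UDF above: if either vector is all zeros, the denominator is 0 and the result is 0.0 / 0.0 = NaN. Also, values such as 0.9999999999999999 in the output are floating-point rounding, not a bug. A minimal standalone sketch of the same formula on plain Array[Double] is shown below. It uses no Spark types, the names are hypothetical, and returning 0.0 for a zero-norm vector is an assumed convention, not something the answer specifies.

```scala
object CosineSketch {
  /** Plain dot product of two equal-length dense arrays. */
  def dot(x: Array[Double], y: Array[Double]): Double = {
    require(x.length == y.length, s"size mismatch: ${x.length} vs ${y.length}")
    var sum = 0.0
    var i = 0
    while (i < x.length) {
      sum += x(i) * y(i)
      i += 1
    }
    sum
  }

  /** Cosine similarity. Returns 0.0 (assumed convention) when either vector has zero norm,
    * where the raw formula would produce 0.0 / 0.0 = NaN. */
  def cosine(x: Array[Double], y: Array[Double]): Double = {
    val denom = math.sqrt(dot(x, x)) * math.sqrt(dot(y, y))
    if (denom == 0.0) 0.0 else dot(x, y) / denom
  }
}
```

To use it on the DataFrame, one could wrap it as udf((a: Vector, b: Vector) => CosineSketch.cosine(a.toArray, b.toArray)). Note that toArray densifies sparse vectors, so for very high-dimensional SparseVectors the copied BLAS.dot, which works on the sparse representation directly, is likely the better choice.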


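Statement: The technical posts on this site are published under the CC BY-SA 4.0 license. If you reproduce them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.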

 
Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM