
How to do vector operations on PySpark DataFrame columns?

I have a dataframe like so:

id | vector1 | id2 | vector2

where the ids are ints and the vectors are of type SparseVector.

For each row, I want to add a column holding the cosine similarity, computed as vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2))), but I can't figure out how to put this into a new column. I've tried writing a udf, but couldn't get it to work.
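In PySpark itself, one way to approach this is a Python UDF that calls the dot method provided by pyspark.ml.linalg vectors. The following is a minimal sketch rather than a tested recipe: the helper names cosine_similarity and add_cosine_column are hypothetical, the column names are taken from the question, and it assumes the columns hold pyspark.ml.linalg vectors (SparseVector or DenseVector). Returning None for an all-zero vector is also an assumption about the desired behaviour.

```python
import math


def cosine_similarity(v1, v2):
    """Cosine similarity of two vectors that expose a .dot(other) method."""
    denom = math.sqrt(v1.dot(v1)) * math.sqrt(v2.dot(v2))
    # Assumption: an all-zero vector yields None (null in Spark) instead of dividing by zero.
    if denom == 0.0:
        return None
    # float() turns a NumPy scalar into a plain Python float for DoubleType.
    return float(v1.dot(v2) / denom)


def add_cosine_column(df, col1="vector1", col2="vector2", out="cosine"):
    """Hypothetical helper: append the cosine similarity of two vector columns."""
    # Imported lazily so cosine_similarity can be tried without Spark installed.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DoubleType

    cosine_udf = udf(cosine_similarity, DoubleType())
    return df.withColumn(out, cosine_udf(df[col1], df[col2]))
```

With the dataframe from the question this would be called as add_cosine_column(df). A Python UDF moves every row between the JVM and the Python worker, so on large data a JVM-side implementation may well be faster.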

Solution using Scala

The Spark repo contains a utility object, org.apache.spark.ml.linalg.BLAS, which uses com.github.fommil.netlib.BLAS to compute dot products. That object is declared private[spark], so it cannot be called from user code; to use it here, we need to copy the utility into the current project as below -

package utils

import com.github.fommil.netlib.{F2jBLAS, BLAS => NetlibBLAS}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

/**
  * The utility object org.apache.spark.ml.linalg.BLAS is package private in the Spark repo,
  * so the dot-product routines are copied here to make them usable from user code.
  * BLAS routines for MLlib's vectors.
  */
object BLAS extends Serializable {

  @transient private var _f2jBLAS: NetlibBLAS = _

  // For level-1 routines, we use Java implementation.
  private def f2jBLAS: NetlibBLAS = {
    if (_f2jBLAS == null) {
      _f2jBLAS = new F2jBLAS
    }
    _f2jBLAS
  }

  /**
    * dot(x, y)
    */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
        " x.size = " + x.size + ", y.size = " + y.size)
    (x, y) match {
      case (dx: DenseVector, dy: DenseVector) =>
        dot(dx, dy)
      case (sx: SparseVector, dy: DenseVector) =>
        dot(sx, dy)
      case (dx: DenseVector, sy: SparseVector) =>
        dot(sy, dx)
      case (sx: SparseVector, sy: SparseVector) =>
        dot(sx, sy)
      case _ =>
        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
    }
  }

  /**
    * dot(x, y)
    */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val n = x.size
    f2jBLAS.ddot(n, x.values, 1, y.values, 1)
  }

  /**
    * dot(x, y)
    */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val nnz = xIndices.length

    var sum = 0.0
    var k = 0
    while (k < nnz) {
      sum += xValues(k) * yValues(xIndices(k))
      k += 1
    }
    sum
  }

  /**
    * dot(x, y)
    */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val yIndices = y.indices
    val nnzx = xIndices.length
    val nnzy = yIndices.length

    var kx = 0
    var ky = 0
    var sum = 0.0
    // advance y until its index catches up with the current index of x
    while (kx < nnzx && ky < nnzy) {
      val ix = xIndices(kx)
      while (ky < nnzy && yIndices(ky) < ix) {
        ky += 1
      }
      if (ky < nnzy && yIndices(ky) == ix) {
        sum += xValues(kx) * yValues(ky)
        ky += 1
      }
      kx += 1
    }
    sum
  }
}
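The sparse-sparse case is the least obvious part of the object above: because both index arrays are sorted, it walks them in step and multiplies only where the indices match. Below is a small Python sketch of the same merge, for illustration only; sparse_dot is a hypothetical name, and the inputs are plain lists standing in for the indices and values arrays of a SparseVector.

```python
def sparse_dot(x_indices, x_values, y_indices, y_values):
    """Dot product of two sparse vectors given as sorted index lists plus values."""
    kx, ky, total = 0, 0, 0.0
    while kx < len(x_indices) and ky < len(y_indices):
        ix = x_indices[kx]
        # Advance y until its index catches up with the current x index.
        while ky < len(y_indices) and y_indices[ky] < ix:
            ky += 1
        # Only positions present in both vectors contribute to the sum.
        if ky < len(y_indices) and y_indices[ky] == ix:
            total += x_values[kx] * y_values[ky]
            ky += 1
        kx += 1
    return total


# Indices 1 and 2 overlap, so the result is 10.0 * 10.0 + 0.5 * 0.5.
print(sparse_dot([1, 2], [10.0, 0.5], [0, 1, 2], [3.0, 10.0, 0.5]))
```

Each index array is traversed at most once, so the cost grows with the number of non-zero entries rather than with the full vector size.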

Find cosine similarity using the above utility

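    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.functions.udf
    // assumes an active SparkSession named spark (as in spark-shell)
    import spark.implicits._

    val df = Seq(
      (0, Vectors.dense(0.0, 10.0, 0.5), 1, Vectors.dense(0.0, 10.0, 0.5)),
      (1, Vectors.dense(0.0, 10.0, 0.2), 2, Vectors.dense(0.0, 10.0, 0.2))
    ).toDF("id", "vector1", "id2", "vector2")
    df.show(false)
    df.printSchema()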
    /**
      * +---+--------------+---+--------------+
      * |id |vector1       |id2|vector2       |
      * +---+--------------+---+--------------+
      * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|
      * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|
      * +---+--------------+---+--------------+
      *
      * root
      * |-- id: integer (nullable = false)
      * |-- vector1: vector (nullable = true)
      * |-- id2: integer (nullable = false)
      * |-- vector2: vector (nullable = true)
      */

    // vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
    val cosine_similarity = udf((vector1: Vector, vector2: Vector) => utils.BLAS.dot(vector1, vector2) /
        (Math.sqrt(utils.BLAS.dot(vector1, vector1)) * Math.sqrt(utils.BLAS.dot(vector2, vector2)))
    )
    df.withColumn("cosine", cosine_similarity($"vector1", $"vector2"))
      .show(false)

    /**
      * +---+--------------+---+--------------+------------------+
      * |id |vector1       |id2|vector2       |cosine            |
      * +---+--------------+---+--------------+------------------+
      * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|0.9999999999999999|
      * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|1.0000000000000002|
      * +---+--------------+---+--------------+------------------+
      */
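The values in the cosine column are not exactly 1 even though each row compares a vector with itself; this is ordinary floating-point rounding in the square roots and the division. If downstream code needs results inside [-1, 1] (for example before taking an arc cosine), the value can be clamped. A hedged Python sketch follows; clamp_cosine is a hypothetical helper, not part of Spark.

```python
import math


def clamp_cosine(value):
    """Clamp a computed cosine similarity into the valid range [-1.0, 1.0]."""
    return max(-1.0, min(1.0, value))


# Value copied from the second row of the output above.
raw = 1.0000000000000002
clamped = clamp_cosine(raw)
# math.acos(raw) would raise ValueError (math domain error); the clamped value is safe.
print(clamped, math.acos(clamped))
```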
