How to do vector operations on PySpark DataFrame columns?
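I have a dataframe like this:
id | vector1 | id2 | vector2
where the ids are integers and the vectors are of type SparseVector.
For each row, I want to add a column holding the cosine similarity, which would be vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2))), but I can't figure out how to put that into a new column. I tried making a udf, but I can't seem to get it to work.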
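For reference, the expression in the question is the usual cosine similarity. Writing the two vectors as $v_1$ and $v_2$ (symbols introduced here only for illustration), it reads:

```latex
\[
\cos(v_1, v_2)
  = \frac{v_1 \cdot v_2}{\sqrt{v_1 \cdot v_1}\,\sqrt{v_2 \cdot v_2}}
  = \frac{v_1 \cdot v_2}{\lVert v_1 \rVert_2 \,\lVert v_2 \rVert_2}
\]
```

Note that both square roots sit in the denominator, and the value is undefined when either vector is all zeros.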
Solution using Scala
There is a utility object, org.apache.spark.ml.linalg.BLAS, in the Spark repo that uses com.github.fommil.netlib.BLAS to compute dot products. However, that object is package-private (visible only inside Spark itself), so to use it here we need to copy the utility into the current project, as follows:
package utils

import com.github.fommil.netlib.{F2jBLAS, BLAS => NetlibBLAS}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

/**
 * The utility object org.apache.spark.ml.linalg.BLAS is package-private in the Spark repo,
 * so its dot-product routines are copied here.
 * BLAS routines for MLlib's vectors and matrices.
 */
object BLAS extends Serializable {

  // Lazily created so the instance is rebuilt on each executor after deserialization.
  @transient private var _f2jBLAS: NetlibBLAS = _

  // For level-1 routines, we use Java implementation.
  private def f2jBLAS: NetlibBLAS = {
    if (_f2jBLAS == null) {
      _f2jBLAS = new F2jBLAS
    }
    _f2jBLAS
  }

  /**
   * dot(x, y)
   */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
      " x.size = " + x.size + ", y.size = " + y.size)
    (x, y) match {
      case (dx: DenseVector, dy: DenseVector) =>
        dot(dx, dy)
      case (sx: SparseVector, dy: DenseVector) =>
        dot(sx, dy)
      case (dx: DenseVector, sy: SparseVector) =>
        dot(sy, dx)
      case (sx: SparseVector, sy: SparseVector) =>
        dot(sx, sy)
      case _ =>
        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
    }
  }

  /**
   * dot(x, y)
   */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val n = x.size
    f2jBLAS.ddot(n, x.values, 1, y.values, 1)
  }

  /**
   * dot(x, y)
   */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val nnz = xIndices.length

    var sum = 0.0
    var k = 0
    while (k < nnz) {
      sum += xValues(k) * yValues(xIndices(k))
      k += 1
    }
    sum
  }

  /**
   * dot(x, y)
   */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val yIndices = y.indices
    val nnzx = xIndices.length
    val nnzy = yIndices.length

    var kx = 0
    var ky = 0
    var sum = 0.0
    // y catching x
    while (kx < nnzx && ky < nnzy) {
      val ix = xIndices(kx)
      while (ky < nnzy && yIndices(ky) < ix) {
        ky += 1
      }
      if (ky < nnzy && yIndices(ky) == ix) {
        sum += xValues(kx) * yValues(ky)
        ky += 1
      }
      kx += 1
    }
    sum
  }
}
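The sparse/sparse case above walks two sorted index arrays in step and multiplies values only where the indices match. A minimal, Spark-free sketch of that same idea is shown below, written as a symmetric two-pointer merge rather than the "y catching x" loop. `Sparse` and `SparseDotSketch` are hypothetical names introduced here for illustration, not Spark classes, and the sketch assumes the indices are strictly increasing, as they are in a SparseVector.

```scala
// Hypothetical stand-in for a sparse vector: strictly increasing indices plus matching values.
final case class Sparse(size: Int, indices: Array[Int], values: Array[Double])

object SparseDotSketch {
  // Two-pointer merge over the sorted index arrays; only matching indices contribute.
  def dot(x: Sparse, y: Sparse): Double = {
    require(x.size == y.size, s"size mismatch: ${x.size} vs ${y.size}")
    var kx = 0
    var ky = 0
    var sum = 0.0
    while (kx < x.indices.length && ky < y.indices.length) {
      val ix = x.indices(kx)
      val iy = y.indices(ky)
      if (ix == iy) {
        sum += x.values(kx) * y.values(ky)
        kx += 1
        ky += 1
      } else if (ix < iy) {
        kx += 1 // x is behind, advance it
      } else {
        ky += 1 // y is behind, advance it
      }
    }
    sum
  }

  def main(args: Array[String]): Unit = {
    val a = Sparse(5, Array(0, 2, 4), Array(1.0, 2.0, 3.0))
    val b = Sparse(5, Array(2, 3, 4), Array(4.0, 5.0, 6.0))
    // Shared indices are 2 and 4: 2.0 * 4.0 + 3.0 * 6.0 = 26.0
    println(dot(a, b))
  }
}
```

The cost is proportional to the number of stored entries in the two vectors, not to their nominal size, which is why this form suits SparseVector columns.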
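// Assumes an active SparkSession named `spark` (as in spark-shell).
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions.udf
import spark.implicits._

val df = Seq(
  (0, Vectors.dense(0.0, 10.0, 0.5), 1, Vectors.dense(0.0, 10.0, 0.5)),
  (1, Vectors.dense(0.0, 10.0, 0.2), 2, Vectors.dense(0.0, 10.0, 0.2))
).toDF("id", "vector1", "id2", "vector2")
df.show(false)
df.printSchema()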
/**
 * +---+--------------+---+--------------+
 * |id |vector1       |id2|vector2       |
 * +---+--------------+---+--------------+
 * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|
 * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|
 * +---+--------------+---+--------------+
 *
 * root
 *  |-- id: integer (nullable = false)
 *  |-- vector1: vector (nullable = true)
 *  |-- id2: integer (nullable = false)
 *  |-- vector2: vector (nullable = true)
 */
// cosine = vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
val cosine_similarity = udf((vector1: Vector, vector2: Vector) =>
  utils.BLAS.dot(vector1, vector2) /
    (Math.sqrt(utils.BLAS.dot(vector1, vector1)) * Math.sqrt(utils.BLAS.dot(vector2, vector2)))
)
df.withColumn("cosine", cosine_similarity($"vector1", $"vector2"))
  .show(false)
/**
 * +---+--------------+---+--------------+------------------+
 * |id |vector1       |id2|vector2       |cosine            |
 * +---+--------------+---+--------------+------------------+
 * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|0.9999999999999999|
 * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|1.0000000000000002|
 * +---+--------------+---+--------------+------------------+
 */
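As a possible alternative, on Spark 3.0 and later `org.apache.spark.ml.linalg.Vector` exposes a public `dot` method, and `Vectors.norm` is public as well, so the copied BLAS object may not be needed there. The sketch below relies on those two calls. The names `CosineSketch`, `cosine`, and `cosineUdf` are hypothetical, the sample rows are made up for illustration, and returning 0.0 when either vector has zero norm is an assumption made here to avoid a NaN from dividing by zero, not something the original answer specifies.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object CosineSketch {
  // Cosine similarity via the public ml.linalg API (Vector.dot needs Spark 3.0+).
  // Assumption: a zero-norm vector yields 0.0 instead of NaN.
  def cosine(v1: Vector, v2: Vector): Double = {
    val denom = Vectors.norm(v1, 2.0) * Vectors.norm(v2, 2.0)
    if (denom == 0.0) 0.0 else v1.dot(v2) / denom
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("cosine-sketch")
      .getOrCreate()
    import spark.implicits._

    // Illustrative rows with sparse vectors, matching the column layout in the question.
    val df = Seq(
      (0, Vectors.sparse(3, Array(1, 2), Array(10.0, 0.5)), 1, Vectors.sparse(3, Array(1, 2), Array(10.0, 0.5))),
      (1, Vectors.sparse(3, Array(0), Array(1.0)), 2, Vectors.sparse(3, Array(1), Array(1.0)))
    ).toDF("id", "vector1", "id2", "vector2")

    val cosineUdf = udf(cosine _)
    df.withColumn("cosine", cosineUdf(col("vector1"), col("vector2"))).show(false)

    spark.stop()
  }
}
```

Because the result is computed in floating point, identical vectors can come out a hair above or below 1.0, as the output above shows. Compare against 1.0 with a tolerance rather than with `==`.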
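Notice: The technical posts on this site are published under the CC BY-SA 4.0 license. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.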