
Cosine similarity over tf idf output in spark dataframe (scala)

I am using Spark with Scala to calculate the cosine similarity between rows of a Dataframe.

The Dataframe schema is as follows:

root
 |-- id: long (nullable = true)
 |-- features: vector (nullable = true)

A sample of the dataframe:

+---+--------------------+
| id|            features|
+---+--------------------+
| 65|(10000,[48,70,87,...|
|191|(10000,[1,73,77,1...|
+---+--------------------+

The code that gets me this result:

val df = spark.read.json("articles_line.json")
val tokenizer = new Tokenizer().setInputCol("desc").setOutputCol("words")
val wordsDF = tokenizer.transform(df)

def flattenWords = udf( (s: Seq[Seq[String]]) => s.flatMap(identity) )
val groupedDF = wordsDF.groupBy("id").
  agg(flattenWords(collect_list("words")).as("grouped_words"))
val hashingTF = new HashingTF().
  setInputCol("grouped_words").setOutputCol("rawFeatures").setNumFeatures(10000)
val featurizedData = hashingTF.transform(groupedDF)
val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
val idfModel = idf.fit(featurizedData)
val rescaledData = idfModel.transform(featurizedData)
val asDense = udf((v: Vector) => v.toDense) // convert sparse TF-IDF vector to dense
val newDf = rescaledData.select('id, 'features)
    .withColumn("dense_features", asDense($"features"))

The final dataframe looks like:

+-----+--------------------+--------------------+
|   id|            features|      dense_features|
+-----+--------------------+--------------------+
|21209|(10000,[128,288,2...|[0.0,0.0,0.0,0.0,...|
|21223|(10000,[8,18,32,4...|[0.0,0.0,0.0,0.0,...|
+-----+--------------------+--------------------+

I don't understand how to work with the "dense_features" column to compute the cosine similarity. That post did not work for me. Any help is appreciated.

A sample of one dense_features row, truncated for brevity:

[[0.0,0.0,0.0,0.0,7.08,0.0,0.0,0.0,0.0,2.24,0.0,0.0,0.0,0.0,0.0, ... ,9.59]]
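At its core, cosine similarity is the dot product of two vectors divided by the product of their norms. A minimal plain-Scala sketch of that computation on two dense arrays (function and argument names are illustrative, no Spark required):

```scala
// Cosine similarity of two equal-length dense vectors:
// dot(a, b) / (||a|| * ||b||)
def cosine(a: Array[Double], b: Array[Double]): Double = {
  val dot   = a.zip(b).map { case (x, y) => x * y }.sum
  val normA = math.sqrt(a.map(x => x * x).sum)
  val normB = math.sqrt(b.map(x => x * x).sum)
  dot / (normA * normB)
}

cosine(Array(1.0, 2.0), Array(1.0, 2.0)) // parallel vectors: ~1.0
cosine(Array(1.0, 0.0), Array(0.0, 1.0)) // orthogonal vectors: 0.0
```

A dense_features value is just such a vector, so the question reduces to applying this per pair of rows.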

This did the trick for me. Full code:

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.distributed._
import org.apache.spark.sql.types._
import org.apache.spark.ml.feature._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val df = spark.read.json("/user/dmitry.korniltsev/lab02/data/DO_record_per_line.json")
val cleaned_df = df
    .withColumn("desc", regexp_replace('desc, "[^\\w\\sа-яА-ЯЁё]", ""))
    .withColumn("desc", lower(trim(regexp_replace('desc, "\\s+", " "))))
    .where(length('desc) > 0)

val tokenizer = new Tokenizer().setInputCol("desc").setOutputCol("words")
val wordsDF = tokenizer.transform(cleaned_df)
def flattenWords = udf( (s: Seq[Seq[String]]) => s.flatMap(identity) )
val hashingTF = new HashingTF()
    .setInputCol("words")
    .setOutputCol("rawFeatures")
    .setNumFeatures(20000)
val featurizedData = hashingTF.transform(wordsDF)
val idf = new IDF()
    .setInputCol("rawFeatures")
    .setOutputCol("features")
val idfModel = idf.fit(featurizedData)
val rescaledData = idfModel.transform(featurizedData)
val asDense = udf((v: Vector) => v.toDense)
val newDf = rescaledData
    .withColumn("dense_features", asDense($"features"))

val cosSimilarity = udf { (x: Vector, y: Vector) => 
    val v1 = x.toArray
    val v2 = y.toArray
    val l1 = scala.math.sqrt(v1.map(x => x*x).sum)
    val l2 = scala.math.sqrt(v2.map(x => x*x).sum)
    val scalar = v1.zip(v2).map(p => p._1*p._2).sum
    scalar/(l1*l2)
    }
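To sanity-check the UDF, one can apply it to a tiny hand-built DataFrame (a sketch, assuming a spark-shell session where `spark` and the `cosSimilarity` UDF above are in scope; the demo ids and values are made up):

```scala
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

val demo = Seq(
  (1L, Vectors.dense(1.0, 2.0, 0.0)),
  (2L, Vectors.dense(2.0, 4.0, 0.0))
).toDF("id", "dense_features")

// Pair every row with every other row and score each pair
val pairs = demo
  .crossJoin(demo.select('id.as("id2"), 'dense_features.as("f2")))
  .filter('id < 'id2)
  .withColumn("cosine_sim", cosSimilarity('dense_features, $"f2"))

pairs.show() // the two vectors are parallel, so cosine_sim is ~1.0
```

The same pattern, with a broadcast join instead of a full cross join, is what the full code uses to keep the pair count manageable.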

val id_list = Seq(23325, 15072, 24506, 3879, 1067, 17019)
val filtered_df = newDf
    .filter(col("id").isin(id_list: _*))
    .select('id.alias("id_frd"), 'dense_features.alias("dense_frd"), 'lang.alias("lang_frd"))

val joinedDf = newDf.join(broadcast(filtered_df), 'id =!= 'id_frd && 'lang === 'lang_frd)
    .withColumn("cosine_sim", cosSimilarity(col("dense_frd"), col("dense_features")))

val filtered = joinedDf
    .filter(col("lang")==="en")
    .withColumn("cosine_sim", when(col("cosine_sim").isNaN, 0).otherwise(col("cosine_sim")))
    .withColumn("rank", row_number().over(
            Window.partitionBy(col("id_frd")).orderBy(col("cosine_sim").desc)))
    .filter(col("rank").between(2, 11))
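The window ranks candidates by descending similarity within each query id, and `between(2, 11)` keeps ten neighbours per id. To inspect the result per query id (a sketch reusing the column names from the code above):

```scala
// Show each query id with its ranked neighbours and scores
filtered
  .select('id_frd, 'id, 'cosine_sim, 'rank)
  .orderBy('id_frd, 'rank)
  .show(20, truncate = false)
```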

