
Scala/Spark - correlation matrix after logistic regression

After running a logistic regression algorithm on a dataset (n = 100 000), I would like to get a correlation matrix of the features.

Here is a preview of my data:

scala> results.columns
res16: Array[String] = Array(label, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked, SexIndex, EmbarkIndex, SexVec, EmbarkVec, features, rawPrediction, probability, prediction)

scala> val fts = results.select("features")
fts: org.apache.spark.sql.DataFrame = [features: vector]

scala> results.select("features").show(10)
+--------------------+
|            features|
+--------------------+
|[1.0,1.0,19.0,1.0...|
|[1.0,1.0,19.0,3.0...|
|[1.0,1.0,22.0,0.0...|
|[1.0,1.0,24.0,0.0...|
|[1.0,1.0,30.0,0.0...|
|[1.0,1.0,31.0,0.0...|
|[1.0,1.0,31.0,1.0...|
|[1.0,1.0,36.0,1.0...|
|(8,[0,1,2,6],[1.0...|
|[1.0,1.0,46.0,1.0...|
+--------------------+

I know that in R I could use rcorr (from the Hmisc package) to get the correlation matrix:

res <- rcorr(as.matrix(my_data)) 

so I tried something similar in Scala:

val corrMatrix = corr(fts)

and got the following error:

<console>:64: error: overloaded method value corr with alternatives:
  (columnName1: String,columnName2: String)org.apache.spark.sql.Column <and>
  (column1: org.apache.spark.sql.Column,column2: org.apache.spark.sql.Column)org.apache.spark.sql.Column
 cannot be applied to (org.apache.spark.sql.DataFrame)

After looking into this error and reading this and this, I think I need to put these arrays into a DataFrame and then iterate over it to compute the correlation for each pair of features, i.e. something like the following pseudocode, where a(i)(j) is the entry in the i-th row and the j-th column:

// a is an n x n matrix; corr(i, j) stands for the correlation between features i and j
for (i <- 0 until n; j <- i until n) {
  if (i == j) a(i)(j) = 1.0
  else {
    a(i)(j) = corr(i, j)
    a(j)(i) = a(i)(j) // symmetric matrix
  }
}
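A minimal plain-Scala sketch of that pairwise loop on local, in-memory data (not Spark), computing the Pearson coefficient for each pair. The helper names pearson and corrMatrix are hypothetical, and the sample rows are made up for illustration:

```scala
// Pearson correlation of two equally long columns x and y.
def pearson(x: Array[Double], y: Array[Double]): Double = {
  val n = x.length
  val mx = x.sum / n
  val my = y.sum / n
  var sxy = 0.0
  var sxx = 0.0
  var syy = 0.0
  for (k <- 0 until n) {
    val dx = x(k) - mx
    val dy = y(k) - my
    sxy += dx * dy
    sxx += dx * dx
    syy += dy * dy
  }
  sxy / math.sqrt(sxx * syy)
}

// rows: one Array[Double] per observation, all with the same number of features.
def corrMatrix(rows: Array[Array[Double]]): Array[Array[Double]] = {
  val p = rows.head.length                               // number of features
  val cols = Array.tabulate(p)(j => rows.map(r => r(j))) // one array per feature
  val a = Array.ofDim[Double](p, p)
  for (i <- 0 until p; j <- i until p) {
    if (i == j) a(i)(j) = 1.0
    else {
      a(i)(j) = pearson(cols(i), cols(j))
      a(j)(i) = a(i)(j) // the matrix is symmetric
    }
  }
  a
}

// Made-up sample: three observations of three features
val rows = Array(
  Array(1.0, 2.0, 3.0),
  Array(2.0, 4.0, 1.0),
  Array(3.0, 6.0, 2.0)
)
val m = corrMatrix(rows)
m.foreach(r => println(r.mkString(" ")))
```

Only the upper triangle is computed and mirrored, as in the pseudocode. A constant feature makes the denominator zero, so its off-diagonal entries come out as NaN.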

I am a complete beginner in Scala and Spark, so I would really appreciate it if someone could help me out.

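In Spark 2.0 or later, you can compute every pairwise correlation with the corr aggregate function from org.apache.spark.sql.functions: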

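import org.apache.spark.ml.linalg._
import org.apache.spark.sql.functions._
import spark.implicits._

// df is the DataFrame with the "features" vector column (results in the question)
val n: Int = df.select("features").first().getAs[Vector](0).size // number of features

// convert the ml Vector into an array<double> so that getItem(i) can be used
val as_array = udf((v: Vector) => v.toArray)

// one corr(...) aggregate per unordered pair of feature indices (i < j);
// indices run from 0 to n - 1, hence "until" rather than "to"
val corrs = (0 until n).combinations(2).map {
  case Seq(i, j) => corr($"vs".getItem(i), $"vs".getItem(j)).alias(s"corr_${i}_$j")
}.toSeq

df.select(as_array($"features").alias("vs")).select(corrs: _*)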
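The select above returns a single row with n(n - 1)/2 correlation columns rather than a matrix. A minimal plain-Scala sketch for folding those values back into a symmetric n x n array, assuming they have already been collected into a Seq[Double] in the same combinations(2) order (e.g. by taking the row with first() and converting each field to Double, assuming none are null). The helper name toMatrix and the sample values are hypothetical:

```scala
// Rebuilds an n x n correlation matrix from pairwise values that arrive in
// the order generated by (0 until n).combinations(2): (0,1), (0,2), ..., (n-2,n-1).
def toMatrix(n: Int, pairCorrs: Seq[Double]): Array[Array[Double]] = {
  require(pairCorrs.length == n * (n - 1) / 2, "expected one value per feature pair")
  // start from the identity: every feature correlates perfectly with itself
  val a = Array.tabulate(n, n)((i, j) => if (i == j) 1.0 else 0.0)
  (0 until n).combinations(2).zip(pairCorrs.iterator).foreach {
    case (Seq(i, j), c) =>
      a(i)(j) = c
      a(j)(i) = c // symmetric
  }
  a
}

// Hypothetical values for n = 3: corr(0,1), corr(0,2), corr(1,2)
val m = toMatrix(3, Seq(0.1, 0.2, 0.3))
m.foreach(r => println(r.mkString(" ")))
```

Regenerating the index pairs with the same combinations(2) call keeps the mapping between columns and matrix positions consistent without parsing the column names.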

You can also use the MLlib Statistics.corr function, which takes an RDD[org.apache.spark.mllib.linalg.Vector]. Here is how to get there:

Generating some data:

scala> import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

scala> import org.apache.spark.sql.Row
import org.apache.spark.sql.Row

scala> import scala.util.Random
import scala.util.Random

scala> import scala.collection.immutable.{Vector => Std_Vec}
import scala.collection.immutable.{Vector=>Std_Vec}

scala> def randVec(n: Int): Std_Vec[Double] = Seq.fill(n)(Random.nextDouble).toVector
randVec: (n: Int)scala.collection.immutable.Vector[Double]

scala> val myDF = sc.parallelize((0 until 10).map(x => (x.toString, randVec(10)))).toDF("lab", "features")
myDF: org.apache.spark.sql.DataFrame = [lab: string, features: array<double>]

scala> myDF.show
+---+--------------------+
|lab|            features|
+---+--------------------+
|  0|[0.81916384524734...|
|  1|[0.22711488489859...|
|  2|[0.52918465208259...|
|  3|[0.29253172322411...|
|  4|[0.22417302941674...|
|  5|[0.21693234260391...|
|  6|[0.39854095726097...|
|  7|[0.58807946374202...|
|  8|[0.96849542301746...|
|  9|[0.93194455754124...|
+---+--------------------+

Transforming and running Statistics.corr:

scala> import spark.implicits._
import spark.implicits._

scala> val featureRDD = myDF.rdd.map{case Row(_, feat: Seq[Double]) => Vectors.dense(feat.toArray)}
featureRDD: org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] = MapPartitionsRDD[31] at map at <console>:46

scala> import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.Statistics

scala> Statistics.corr(featureRDD)
res10: org.apache.spark.mllib.linalg.Matrix =
1.0                   0.40745998071406825   ... (10 total)
0.40745998071406825   1.0                   ...
-0.08774724980258353  -0.40530111151726017  ...
0.01094426191127371   -0.2586807037180266   ...
0.39307374354852526   0.8309967336954839    ...
0.29758193455372256   0.5102634076586834    ...
0.15412639422865976   -0.07047908269724495  ...
-0.34671405612623457  0.13551628442995656   ...
0.296600595616234     -0.16362444756013478  ...
-0.13393787396551504  -0.42967054975951785  ...
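The REPL printout above is truncated ("..."). To work with every entry, one option is to copy the result into nested arrays. mllib's Matrix exposes numRows, numCols and toArray, and toArray returns the entries in column-major order; assuming that layout, a minimal pure-Scala helper (the name columnMajorToRows is hypothetical) could look like this:

```scala
// values: the entries of a numRows x numCols matrix in column-major order,
// the layout returned by mllib's Matrix.toArray.
def columnMajorToRows(numRows: Int, numCols: Int, values: Array[Double]): Array[Array[Double]] = {
  require(values.length == numRows * numCols, "size mismatch")
  // entry (i, j) sits at offset i + j * numRows in column-major storage
  Array.tabulate(numRows, numCols)((i, j) => values(i + j * numRows))
}

// A 2 x 3 example whose columns are (1, 2), (3, 4) and (5, 6)
val rows = columnMajorToRows(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
rows.foreach(r => println(r.mkString(" ")))
```

With the result above, the call would be columnMajorToRows(m.numRows, m.numCols, m.toArray), where m is the Matrix returned by Statistics.corr. Because a correlation matrix is symmetric, mixing up the order would go unnoticed there, which is why the example uses a non-square matrix.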
