After running a logistic regression algorithm on a dataset (n = 100 000), I would like to get a correlation matrix of the features.
Here is a preview of my data:
scala> results.columns
res16: Array[String] = Array(label, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked, SexIndex, EmbarkIndex, SexVec, EmbarkVec, features, rawPrediction, probability, prediction)
scala> val fts = results.select("features")
fts: org.apache.spark.sql.DataFrame = [features: vector]
scala> results.select("features").show(10)
+--------------------+
|            features|
+--------------------+
|[1.0,1.0,19.0,1.0...|
|[1.0,1.0,19.0,3.0...|
|[1.0,1.0,22.0,0.0...|
|[1.0,1.0,24.0,0.0...|
|[1.0,1.0,30.0,0.0...|
|[1.0,1.0,31.0,0.0...|
|[1.0,1.0,31.0,1.0...|
|[1.0,1.0,36.0,1.0...|
|(8,[0,1,2,6],[1.0...|
|[1.0,1.0,46.0,1.0...|
+--------------------+
I know that in R I could use this code to get the correlation matrix:
res <- rcorr(as.matrix(my_data))
so I tried something similar with Scala:
val corrMatrix = corr(fts)
and got the following error:
<console>:64: error: overloaded method value corr with alternatives:
(columnName1: String,columnName2: String)org.apache.spark.sql.Column <and>
(column1: org.apache.spark.sql.Column,column2: org.apache.spark.sql.Column)org.apache.spark.sql.Column
cannot be applied to (org.apache.spark.sql.DataFrame)
After looking into this error and reading this and this, I think I need to put these arrays into a DataFrame and then iterate through it to find the correlation between each pair of features, i.e. something like this pseudocode, where (i, j) is the entry in the i-th row and the j-th column:
for (int i = 1; i <= n; i++) {
    for (int j = i; j <= n; j++) {
        if (i == j) a(i)(j) = 1
        else a(i)(j) = a(j)(i) = corr(i, j) // symmetric matrix
    }
}
I am a complete beginner in Scala and Spark, so I would really appreciate it if someone could help me out.
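For reference, the loop above can be sketched in plain Scala on local arrays. This is only an illustration, assuming each feature column fits in memory as an Array[Double]; the names LocalCorr, pearson and corrMatrix are hypothetical helpers, not Spark APIs.

```scala
object LocalCorr {
  // Pearson correlation between two equally long columns.
  def pearson(x: Array[Double], y: Array[Double]): Double = {
    require(x.length == y.length && x.nonEmpty, "columns must be non-empty and equally long")
    val n = x.length
    val mx = x.sum / n
    val my = y.sum / n
    var sxy = 0.0
    var sxx = 0.0
    var syy = 0.0
    for (k <- 0 until n) {
      val dx = x(k) - mx
      val dy = y(k) - my
      sxy += dx * dy
      sxx += dx * dx
      syy += dy * dy
    }
    sxy / math.sqrt(sxx * syy) // NaN if either column is constant
  }

  // Symmetric correlation matrix; cols(i) holds all values of feature i.
  def corrMatrix(cols: Array[Array[Double]]): Array[Array[Double]] = {
    val n = cols.length
    val a = Array.ofDim[Double](n, n)
    for (i <- 0 until n; j <- i until n) {
      if (i == j) a(i)(j) = 1.0
      else {
        val c = pearson(cols(i), cols(j))
        a(i)(j) = c // upper triangle
        a(j)(i) = c // mirrored entry, since the matrix is symmetric
      }
    }
    a
  }
}
```

Only the upper triangle is computed, so the work is n(n-1)/2 correlations. On a dataset of the size mentioned in the question, a distributed computation in Spark is likely the better fit; this version is mainly useful for checking results on a small sample.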
In Spark 2.0 or later you can:
import org.apache.spark.ml.linalg._
import org.apache.spark.sql.functions._
import spark.implicits._
val n: Int = ??? // number of features
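// Convert the ML vector column into an array column so getItem can index it
val as_array = udf((v: Vector) => v.toArray)
// One corr expression per unordered pair (i, j) with 0 <= i < j < n
val corrs = (0 until n).combinations(2).map {
  case Seq(i, j) => corr($"vs".getItem(i), $"vs".getItem(j))
}.toSeq
df.select(as_array($"features").alias("vs")).select(corrs: _*)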
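The select above yields a single row with one correlation per unordered pair of feature indices, in the order that combinations(2) generates them. A minimal sketch of folding such a flat sequence back into a symmetric matrix follows; PairsToMatrix and toMatrix are hypothetical names, and the input is assumed to be already collected to the driver as plain doubles.

```scala
object PairsToMatrix {
  // flat: one correlation per pair (i, j), i < j, in combinations(2) order
  // over the feature indices 0 until n.
  def toMatrix(n: Int, flat: Seq[Double]): Array[Array[Double]] = {
    require(flat.length == n * (n - 1) / 2, "expected one value per unordered pair")
    // Start from the identity: every feature correlates perfectly with itself.
    val m = Array.tabulate(n, n)((i, j) => if (i == j) 1.0 else 0.0)
    (0 until n).combinations(2).toSeq.zip(flat).foreach {
      case (Seq(i, j), c) =>
        m(i)(j) = c
        m(j)(i) = c // mirror into the lower triangle
    }
    m
  }
}
```

For three features, the pairs arrive as (0,1), (0,2), (1,2), so toMatrix(3, Seq(0.1, 0.2, 0.3)) places 0.1 at positions (0,1) and (1,0), and so on.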
You will want to use the MLlib corr function on an RDD[org.apache.spark.mllib.linalg.Vector]; here is how you get there:
scala> import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
scala> import org.apache.spark.sql.Row
import org.apache.spark.sql.Row
scala> import scala.util.Random
import scala.util.Random
scala> import scala.collection.immutable.{Vector => Std_Vec}
import scala.collection.immutable.{Vector=>Std_Vec}
scala> def randVec(n: Int): Std_Vec[Double] = Seq.fill(n)(Random.nextDouble).toVector
randVec: (n: Int)scala.collection.immutable.Vector[Double]
scala> val myDF = sc.parallelize((0 until 10).map(x => (x.toString, randVec(10)))).toDF("lab", "features")
myDF: org.apache.spark.sql.DataFrame = [lab: string, features: array<double>]
scala> myDF.show
+---+--------------------+
|lab|            features|
+---+--------------------+
| 0|[0.81916384524734...|
| 1|[0.22711488489859...|
| 2|[0.52918465208259...|
| 3|[0.29253172322411...|
| 4|[0.22417302941674...|
| 5|[0.21693234260391...|
| 6|[0.39854095726097...|
| 7|[0.58807946374202...|
| 8|[0.96849542301746...|
| 9|[0.93194455754124...|
+---+--------------------+
Then build the RDD of vectors and call Statistics.corr:
scala> import spark.implicits._
import spark.implicits._
scala> val featureRDD = myDF.rdd.map{case Row(_, feat: Seq[Double]) => Vectors.dense(feat.toArray)}
featureRDD: org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] = MapPartitionsRDD[31] at map at <console>:46
scala> import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.Statistics
scala> Statistics.corr(featureRDD)
res10: org.apache.spark.mllib.linalg.Matrix =
1.0 0.40745998071406825 ... (10 total)
0.40745998071406825 1.0 ...
-0.08774724980258353 -0.40530111151726017 ...
0.01094426191127371 -0.2586807037180266 ...
0.39307374354852526 0.8309967336954839 ...
0.29758193455372256 0.5102634076586834 ...
0.15412639422865976 -0.07047908269724495 ...
-0.34671405612623457 0.13551628442995656 ...
0.296600595616234 -0.16362444756013478 ...
-0.13393787396551504 -0.42967054975951785 ...
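If the features column already holds org.apache.spark.ml.linalg vectors, as in the question, the DataFrame-based Correlation.corr may avoid the RDD conversion. The following is a sketch assuming Spark 2.2 or later; MlCorrSketch and featureCorr are hypothetical wrapper names, and Pearson correlation is the default method.

```scala
import org.apache.spark.ml.linalg.Matrix
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.sql.{DataFrame, Row}

object MlCorrSketch {
  // Returns the correlation matrix of the vector column `col`.
  // Correlation.corr produces a one-row DataFrame whose single cell is the matrix.
  def featureCorr(df: DataFrame, col: String = "features"): Matrix = {
    val Row(coeff: Matrix) = Correlation.corr(df, col).head
    coeff
  }
}
```

With the DataFrame from the question this would be called as MlCorrSketch.featureCorr(results), and individual entries can be read with coeff(i, j).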