
Multiclass classification: showing raw predictions more readably in Scala with Spark

I'm working with the Iris dataset (LogisticRegressionWithLBFGS(), multiclass classification). I pulled my data into an RDD, converted it to a DataFrame, and did some tidying up. I created a label index on the Iris plant class/label field and a feature vector from the other fields. I then took these two DataFrame columns and converted them into an RDD of LabeledPoint instances, which I can feed into LogisticRegressionWithLBFGS().
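For anyone reproducing these steps, here is a minimal sketch of that preprocessing with Spark's `StringIndexer` and `VectorAssembler`, followed by the conversion to an `RDD[LabeledPoint]`. The column names (`sepal_length`, `sepal_width`, `petal_length`, `petal_width`, `species`) and the `IrisPrep`/`toLabeledPoints` names are assumptions, not taken from the question. It assumes Spark 2.x or later, where `Vectors.fromML` converts the newer `ml` vectors into the `mllib` vectors that `LabeledPoint` uses.

```scala
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

// Hypothetical helper: the column names below are assumed, not from the question.
object IrisPrep {
  val featureCols = Array("sepal_length", "sepal_width", "petal_length", "petal_width")

  def toLabeledPoints(iris: DataFrame): RDD[LabeledPoint] = {
    // Index the class name ("species") into a numeric "label" column (0.0, 1.0, 2.0).
    val indexed = new StringIndexer()
      .setInputCol("species")
      .setOutputCol("label")
      .fit(iris)
      .transform(iris)

    // Combine the numeric measurement columns into a single "features" vector.
    val assembled = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("features")
      .transform(indexed)

    // Convert each row into the RDD-based LabeledPoint that LogisticRegressionWithLBFGS expects.
    assembled.select("label", "features").rdd.map { row =>
      LabeledPoint(row.getDouble(0), Vectors.fromML(row.getAs[MLVector](1)))
    }
  }
}
```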

Here is some predictor code:

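import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

// Iris has 3 classes, so the label index takes the values 0.0, 1.0 and 2.0
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(3)
  .setIntercept(true)
  .setValidateData(true)
  .run(training) // training: RDD[LabeledPoint]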

Scores and labels:

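val scoreAndLabels_ofTrain = training.map { point =>
  // predict returns the predicted class index (e.g. 2.0), not a probability
  val score = model.predict(point.features)
  (score, point.label) // (prediction, actual label)
}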

To see the predictions, I ran:

scoreAndLabels_ofTrain.take(200).foreach(println)

The only problem is that I got this example, pretty much as-is, from a book. I was kind of hoping to see a dataset that shows the feature columns, the predicted value, the probability score it gave, etc. I imagine I'd need to convert the label index back if I wanted to see the string values it represents.

How do I get better-looking, tabular output, as close as possible to the original dataset, with the predictions alongside it? I think I'm missing a trick somewhere.
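One way to get closer to the original dataset while still using the RDD-based model is to keep the features next to each prediction and turn the result into a DataFrame. This sketch assumes a `labels` array holding the class names in label-index order (for example, taken from the fitted `StringIndexerModel`); the `PredictionTable`/`toTable` names are hypothetical. As far as I know, the multiclass `mllib` model's `predict` only returns the class index, so this table has no probability column; for probabilities, the DataFrame-based `spark.ml` API is the better fit.

```scala
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical helper that builds a readable prediction table from an mllib model.
object PredictionTable {
  def toTable(spark: SparkSession,
              model: LogisticRegressionModel,
              data: RDD[LabeledPoint],
              labels: Array[String]): DataFrame = {
    import spark.implicits._
    data.map { p =>
      val predicted = model.predict(p.features)
      (
        p.features.toArray.mkString(", "), // original measurements
        labels(p.label.toInt),             // actual class name
        labels(predicted.toInt),           // predicted class name
        predicted == p.label               // whether the prediction matched
      )
    }.toDF("features", "actual", "predicted", "correct")
  }
}
```

Usage would look like `PredictionTable.toTable(spark, model, training, labels).show(200, truncate = false)`, where `labels` is the assumed array of class names.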

The output of the above looks like:

(2.0,2.0)
(2.0,2.0)
(2.0,2.0)
(2.0,2.0)
(2.0,2.0)
...

What does this even mean? I'm not sure how to read or interpret the data. For the first line, is it saying that it predicted 2.0 and the actual label was 2.0? Am I understanding it right?
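Because `scoreAndLabels_ofTrain` is built as `(score, point.label)` and `score` comes from `model.predict`, each printed pair reads as (predicted class, actual class). The plain-Scala sketch below, using made-up sample pairs and hypothetical names, spells each pair out and computes the fraction that match.

```scala
// Hypothetical helper for reading (prediction, label) pairs; the sample data is made up.
object ReadPairs {
  // Fraction of pairs where the predicted class equals the actual class.
  def accuracy(pairs: Seq[(Double, Double)]): Double =
    if (pairs.isEmpty) 0.0
    else pairs.count { case (prediction, label) => prediction == label }.toDouble / pairs.size

  // Turns one (prediction, label) pair into a readable sentence.
  def describe(pair: (Double, Double)): String = {
    val (prediction, label) = pair
    val verdict = if (prediction == label) "correct" else "wrong"
    s"predicted class $prediction, actual class $label ($verdict)"
  }

  def main(args: Array[String]): Unit = {
    val sample = Seq((2.0, 2.0), (2.0, 2.0), (1.0, 2.0), (0.0, 0.0))
    sample.map(describe).foreach(println)
    println(s"accuracy = ${accuracy(sample)}")
  }
}
```

On the real RDD, Spark's `org.apache.spark.mllib.evaluation.MulticlassMetrics` accepts the same (prediction, label) pairs, so `new MulticlassMetrics(scoreAndLabels_ofTrain).accuracy` and `.confusionMatrix` give the same kind of summary at scale.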

Yes. What you have is (prediction, label) pairs in the form of an RDD[(Double, Double)], produced when you map over the input dataset and make a prediction for each element. However, you are using the MLlib (RDD-based) LR implementation; you can use the DataFrame-based implementation directly instead. Take a look at the example. The fit function trains the model and returns a LogisticRegressionModel. Apply its transform method to your input DataFrame and new columns will be added: prediction, along with rawPrediction and probability.
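A minimal sketch of that DataFrame-based approach, assuming Spark 2.1 or later (for `setFamily`) and an Iris DataFrame with hypothetical columns `sepal_length`, `sepal_width`, `petal_length`, `petal_width` and `species`. It keeps the fitted `StringIndexerModel` so that `IndexToString` can turn the numeric prediction back into a species name. The `IrisMlPredictions`/`predict` names are illustrative; on Spark 3.x, `indexerModel.labels` is deprecated in favour of `labelsArray` but still works.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// Hypothetical helper; the column names are assumptions about the Iris DataFrame.
object IrisMlPredictions {
  val featureCols = Array("sepal_length", "sepal_width", "petal_length", "petal_width")

  def predict(iris: DataFrame): DataFrame = {
    // Keep the fitted indexer so its labels can decode predictions later.
    val indexerModel = new StringIndexer()
      .setInputCol("species")
      .setOutputCol("label")
      .fit(iris)

    val prepared = new VectorAssembler()
      .setInputCols(featureCols)
      .setOutputCol("features")
      .transform(indexerModel.transform(iris))

    // fit trains the model; "multinomial" covers all three Iris classes.
    val lrModel = new LogisticRegression()
      .setFamily("multinomial")
      .setMaxIter(100)
      .fit(prepared)

    // transform appends rawPrediction, probability and prediction columns.
    val predicted = lrModel.transform(prepared)

    // Map the numeric prediction back to the species name.
    val decoded = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedSpecies")
      .setLabels(indexerModel.labels)
      .transform(predicted)

    // Original measurements, true class, predicted class and per-class probabilities side by side.
    decoded.select((featureCols ++ Array("species", "predictedSpecies", "probability")).map(col): _*)
  }
}
```

Calling `IrisMlPredictions.predict(irisDf).show(200, truncate = false)` (with `irisDf` standing in for your DataFrame) then prints one row per flower, which is close to the tabular view asked for in the question.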
