
How do I get the predicted values from Spark MLlib's JavaDecisionTreeRegressionExample.java?

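I want to get the predicted values from JavaDecisionTreeRegressionExample.java, not just the printed description of the learned decision tree and evaluation metrics such as MAE and RMSE. Does anyone know how to do this, or which method I can use to get the predicted values?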

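I have tried many of the methods provided by the RegressionEvaluator and DecisionTreeRegressionModel classes, but I still can't figure out how to get the predictions. If anyone knows how to do this, please tell me. Thank you very much!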

Below is the source code of JavaDecisionTreeRegressionExample.java:

package org.apache.spark.examples.ml;
// $example on$
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// $example off$

public class JavaDecisionTreeRegressionExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaDecisionTreeRegressionExample")
      .getOrCreate();
    // $example on$
    // Load the data stored in LIBSVM format as a DataFrame.
    Dataset<Row> data = spark.read().format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt");

    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 4 distinct values are treated as continuous.
    VectorIndexerModel featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data);

    // Split the data into training and test sets (30% held out for testing).
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeRegressor dt = new DecisionTreeRegressor()
      .setFeaturesCol("indexedFeatures");

    // Chain indexer and tree in a Pipeline.
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[]{featureIndexer, dt});

    // Train model. This also runs the indexer.
    PipelineModel model = pipeline.fit(trainingData);

    // Make predictions.
    Dataset<Row> predictions = model.transform(testData);

    // Select example rows to display.
    predictions.select("label", "features").show(5);

    // Select (prediction, true label) and compute test error.
    RegressionEvaluator evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse");
    double rmse = evaluator.evaluate(predictions);
    System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

    DecisionTreeRegressionModel treeModel =
      (DecisionTreeRegressionModel) (model.stages()[1]);
    System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
    // $example off$

    spark.stop();
  }
}
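The RegressionEvaluator in this example only reduces the prediction and label columns to a single number, which is why neither it nor treeModel.toDebugString() ever hands back the per-row predictions. As a rough illustration, here is a plain-Java sketch, with no Spark dependency, of what the "rmse" and "mae" metric names compute from (prediction, label) pairs. The class and method names are hypothetical, and this is not Spark's own implementation.

```java
// Hypothetical, Spark-free illustration of the "rmse" and "mae" metrics
// that RegressionEvaluator reports; this is not Spark's implementation.
public final class RegressionMetricsSketch {

  private RegressionMetricsSketch() {}

  // Root mean squared error: sqrt(mean((prediction - label)^2)).
  public static double rmse(double[] predictions, double[] labels) {
    checkSameLength(predictions, labels);
    double sumSq = 0.0;
    for (int i = 0; i < predictions.length; i++) {
      double err = predictions[i] - labels[i];
      sumSq += err * err;
    }
    return Math.sqrt(sumSq / predictions.length);
  }

  // Mean absolute error: mean(|prediction - label|).
  public static double mae(double[] predictions, double[] labels) {
    checkSameLength(predictions, labels);
    double sumAbs = 0.0;
    for (int i = 0; i < predictions.length; i++) {
      sumAbs += Math.abs(predictions[i] - labels[i]);
    }
    return sumAbs / predictions.length;
  }

  // Both metrics need one prediction per label and at least one row.
  private static void checkSameLength(double[] predictions, double[] labels) {
    if (predictions.length != labels.length || predictions.length == 0) {
      throw new IllegalArgumentException("need equal-length, non-empty arrays");
    }
  }

  public static void main(String[] args) {
    double[] predictions = {1.0, 0.0, 3.0};
    double[] labels = {1.0, 2.0, 1.0};
    // Errors are 0, -2 and 2, so RMSE = sqrt(8/3) and MAE = 4/3.
    System.out.println("RMSE = " + rmse(predictions, labels));
    System.out.println("MAE  = " + mae(predictions, labels));
  }
}
```

Each metric collapses many predictions into one value, so the individual predictions have to be read from the DataFrame returned by model.transform(testData) instead.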

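I solved my problem. Change predictions.select("label", "features").show(5); to predictions.select("prediction", "label", "features").show(5); and the predicted values will be displayed. model.transform(testData) already appends them in a column named "prediction"; the original select simply did not include that column.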
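show(5) only prints the first five rows. To use the predicted values in Java code rather than just print them, one option is to collect the "prediction" column to the driver. Below is a minimal sketch, assuming spark-sql is on the classpath and the predictions DataFrame is small enough to fit in driver memory. The class PredictionExtractor and its two helper methods are hypothetical names, and main runs against a small hand-built DataFrame instead of the real model output.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class PredictionExtractor {

  // Hypothetical helper: returns the "prediction" column as a driver-side list.
  // collectAsList() pulls every row to the driver, so use it only on small results.
  public static List<Double> collectPredictions(Dataset<Row> predictions) {
    return predictions.select("prediction").as(Encoders.DOUBLE()).collectAsList();
  }

  // Hypothetical helper: returns {prediction, label} pairs, one double[2] per row.
  public static List<double[]> collectPredictionLabelPairs(Dataset<Row> predictions) {
    List<double[]> pairs = new ArrayList<>();
    for (Row row : predictions.select("prediction", "label").collectAsList()) {
      pairs.add(new double[]{row.getDouble(0), row.getDouble(1)});
    }
    return pairs;
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("PredictionExtractor")
        .master("local[1]")
        .getOrCreate();

    // Stand-in for model.transform(testData): only the columns the helpers read.
    StructType schema = new StructType()
        .add("label", DataTypes.DoubleType)
        .add("prediction", DataTypes.DoubleType);
    Dataset<Row> predictions = spark.createDataFrame(
        Arrays.asList(RowFactory.create(1.0, 0.9), RowFactory.create(0.0, 0.2)),
        schema);

    System.out.println("Predictions: " + collectPredictions(predictions));
    for (double[] pair : collectPredictionLabelPairs(predictions)) {
      System.out.println("prediction=" + pair[0] + ", label=" + pair[1]);
    }
    spark.stop();
  }
}
```

In the example program itself, the call would be collectPredictions(model.transform(testData)). For large test sets, writing the selected columns out with the DataFrame writer is usually safer than collecting them.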


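Notice: The technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.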
