
How to use the linear regression of MLlib of Apache Spark?

I am new to Apache Spark. In the MLlib documentation I found a Scala example, but I really don't know Scala. Does anyone know of an example in Java? Thanks! The sample code is

import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint

// Load and parse the data
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, parts(1).split(' ').map(x => x.toDouble).toArray)
}

// Building the model
val numIterations = 20
val model = LinearRegressionWithSGD.train(parsedData, numIterations)

// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.reduce(_ + _) / valuesAndPreds.count
println("training Mean Squared Error = " + MSE)

from the MLlib documentation. Thanks!

As stated in the documentation:

All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate JavaRDD class. You can convert a Java RDD to a Scala one by calling .rdd() on your JavaRDD object.
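
For example, here is a minimal, self-contained sketch of that .rdd() bridge (the class name and the tiny in-memory data set are made up purely for illustration, they are not from the question):

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;

public class RddBridgeExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("rddBridge").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // A tiny made-up training set: a label followed by a one-dimensional feature.
        JavaRDD<LabeledPoint> trainingData = sc.parallelize(Arrays.asList(
                new LabeledPoint(1.0, Vectors.dense(1.0)),
                new LabeledPoint(2.0, Vectors.dense(2.0)),
                new LabeledPoint(3.0, Vectors.dense(3.0))));

        // train() expects a Scala RDD, so the JavaRDD is converted with .rdd().
        LinearRegressionModel model =
                LinearRegressionWithSGD.train(trainingData.rdd(), 20);

        System.out.println("weights: " + model.weights());
        sc.stop();
    }
}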

It is not trivial, since you still have to reproduce the Scala code in Java, but it works (at least in this case).

That being said, here is a Java implementation:

// Imports needed for this snippet (declared at the top of the enclosing class):
//   org.apache.spark.SparkConf
//   org.apache.spark.api.java.JavaRDD
//   org.apache.spark.api.java.JavaSparkContext
//   org.apache.spark.api.java.function.Function
//   org.apache.spark.mllib.linalg.Vectors
//   org.apache.spark.mllib.regression.LabeledPoint
//   org.apache.spark.mllib.regression.LinearRegressionModel
//   org.apache.spark.mllib.regression.LinearRegressionWithSGD
//   scala.Tuple2
public void linReg() {
    String master = "local";
    SparkConf conf = new SparkConf().setAppName("csvParser").setMaster(master);
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaRDD<String> data = sc.textFile("mllib/data/ridge-data/lpsa.data");
    JavaRDD<LabeledPoint> parseddata = data
            .map(new Function<String, LabeledPoint>() {
                // I see no way of just using a lambda here, hence more verbosity than with Scala
                @Override
                public LabeledPoint call(String line) throws Exception {
                    String[] parts = line.split(",");
                    String[] pointsStr = parts[1].split(" ");
                    double[] points = new double[pointsStr.length];
                    for (int i = 0; i < pointsStr.length; i++)
                        points[i] = Double.valueOf(pointsStr[i]);
                    return new LabeledPoint(Double.valueOf(parts[0]),
                            Vectors.dense(points));
                }
            });

    // Building the model
    int numIterations = 20;
    LinearRegressionModel model = LinearRegressionWithSGD.train(
            parseddata.rdd(), numIterations); // notice the .rdd()

    // Evaluate model on training examples and compute training error
    JavaRDD<Tuple2<Double, Double>> valuesAndPred = parseddata
            .map(point -> new Tuple2<Double, Double>(point.label(), model
                    .predict(point.features())));
    // The important point here is the explicit creation of the Tuple2.

    double MSE = valuesAndPred.mapToDouble(
            tuple -> Math.pow(tuple._1 - tuple._2, 2)).mean();
    // you can compute the mean with this function, which is much easier
    System.out.println("training Mean Squared Error = "
            + String.valueOf(MSE));
}

It is far from perfect, but I hope it gives you a better idea of how to use the Scala examples from the MLlib documentation.
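
As a side note, if you compile against Java 8 the anonymous Function can also be written as a lambda, which brings the code much closer to the Scala original. Below is an untested sketch of the same pipeline, assuming the same JavaSparkContext sc, file path, and imports as in the method above:

// Hedged sketch only: the same pipeline written with Java 8 lambdas.
// Assumes the same sc, file path and imports as the linReg() method above.
JavaRDD<String> data = sc.textFile("mllib/data/ridge-data/lpsa.data");
JavaRDD<LabeledPoint> parsedData = data.map(line -> {
    String[] parts = line.split(",");
    String[] featureStr = parts[1].split(" ");
    double[] features = new double[featureStr.length];
    for (int i = 0; i < featureStr.length; i++) {
        features[i] = Double.parseDouble(featureStr[i]);
    }
    return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(features));
});

LinearRegressionModel model = LinearRegressionWithSGD.train(parsedData.rdd(), 20);

double mse = parsedData
        .mapToDouble(p -> Math.pow(p.label() - model.predict(p.features()), 2))
        .mean();
System.out.println("training Mean Squared Error = " + mse);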

package org.apache.spark.examples;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import java.util.regex.Pattern;

/**
 * Logistic regression based classification.
 *
 * This is an example implementation for learning how to use Spark. For more conventional use,
 * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
 * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
 */
public final class JavaHdfsLR {

  private static final int D = 10;   // Number of dimensions
  private static final Random rand = new Random(42);

  static void showWarning() {
    String warning = "WARN: This is a naive implementation of Logistic Regression " +
            "and is given as an example!\n" +
            "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " +
            "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " +
            "for more conventional use.";
    System.err.println(warning);
  }

  static class DataPoint implements Serializable {
    DataPoint(double[] x, double y) {
      this.x = x;
      this.y = y;
    }

    double[] x;
    double y;
  }

  static class ParsePoint implements Function<String, DataPoint> {
    private static final Pattern SPACE = Pattern.compile(" ");

    @Override
    public DataPoint call(String line) {
      String[] tok = SPACE.split(line);
      double y = Double.parseDouble(tok[0]);
      double[] x = new double[D];
      for (int i = 0; i < D; i++) {
        x[i] = Double.parseDouble(tok[i + 1]);
      }
      return new DataPoint(x, y);
    }
  }

  static class VectorSum implements Function2<double[], double[], double[]> {
    @Override
    public double[] call(double[] a, double[] b) {
      double[] result = new double[D];
      for (int j = 0; j < D; j++) {
        result[j] = a[j] + b[j];
      }
      return result;
    }
  }

  static class ComputeGradient implements Function<DataPoint, double[]> {
    private final double[] weights;

    ComputeGradient(double[] weights) {
      this.weights = weights;
    }

    @Override
    public double[] call(DataPoint p) {
      double[] gradient = new double[D];
      for (int i = 0; i < D; i++) {
        double dot = dot(weights, p.x);
        gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
      }
      return gradient;
    }
  }

  public static double dot(double[] a, double[] b) {
    double x = 0;
    for (int i = 0; i < D; i++) {
      x += a[i] * b[i];
    }
    return x;
  }

  public static void printWeights(double[] a) {
    System.out.println(Arrays.toString(a));
  }

  public static void main(String[] args) {

    if (args.length < 2) {
      System.err.println("Usage: JavaHdfsLR <file> <iters>");
      System.exit(1);
    }

    showWarning();

    SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    JavaRDD<String> lines = sc.textFile(args[0]);
    JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
    int ITERATIONS = Integer.parseInt(args[1]);

    // Initialize w to a random value
    double[] w = new double[D];
    for (int i = 0; i < D; i++) {
      w[i] = 2 * rand.nextDouble() - 1;
    }

    System.out.print("Initial w: ");
    printWeights(w);

    for (int i = 1; i <= ITERATIONS; i++) {
      System.out.println("On iteration " + i);

      double[] gradient = points.map(
        new ComputeGradient(w)
      ).reduce(new VectorSum());

      for (int j = 0; j < D; j++) {
        w[j] -= gradient[j];
      }

    }

    System.out.print("Final w: ");
    printWeights(w);
    sc.stop();
  }
}
