如何在 Java 中使用 Apache Spark MLlib 的线性回归?

我是 Apache Spark 的新手。我在 MLlib 的文档中找到了一个 Scala 示例,但我并不懂 Scala。有人知道对应的 Java 示例吗?谢谢!示例代码如下:

// MLlib linear-regression example. Assumes `sc` is an existing SparkContext
// (e.g. the one spark-shell provides) — TODO confirm for your environment.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LinearRegressionWithSGD

// Load and parse the data: each line is "label,f1 f2 f3 ..." (label before the
// comma, space-separated feature values after it).
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
  val parts = line.split(',')
  // LabeledPoint expects an MLlib Vector for its features, not a raw Array[Double].
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}.cache() // SGD makes many passes over the data; avoid re-reading and re-parsing the file

// Building the model
val numIterations = 20
val model = LinearRegressionWithSGD.train(parsedData, numIterations)

// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.reduce(_ + _) / valuesAndPreds.count
println("training Mean Squared Error = " + MSE)

(以上代码摘自 MLlib 文档。)谢谢!

如文档中所示:

MLlib 的所有方法都使用对 Java 友好的类型,因此您可以像在 Scala 中一样导入和调用它们。唯一需要注意的是,这些方法接受的是 Scala RDD 对象,而 Spark Java API 使用的是单独的 JavaRDD 类。您可以在 JavaRDD 对象上调用 .rdd() 将其转换为 Scala RDD。

这并不算简单,因为您仍然需要用 Java 重写 Scala 代码,但它是可行的(至少在这个例子中是这样)。

话虽如此,下面是一个 Java 实现:

 public void linReg() { String master = "local"; SparkConf conf = new SparkConf().setAppName("csvParser").setMaster( master); JavaSparkContext sc = new JavaSparkContext(conf); JavaRDD data = sc.textFile("mllib/data/ridge-data/lpsa.data"); JavaRDD parseddata = data .map(new Function() { // I see no ways of just using a lambda, hence more verbosity than with scala @Override public LabeledPoint call(String line) throws Exception { String[] parts = line.split(","); String[] pointsStr = parts[1].split(" "); double[] points = new double[pointsStr.length]; for (int i = 0; i < pointsStr.length; i++) points[i] = Double.valueOf(pointsStr[i]); return new LabeledPoint(Double.valueOf(parts[0]), Vectors.dense(points)); } }); // Building the model int numIterations = 20; LinearRegressionModel model = LinearRegressionWithSGD.train( parseddata.rdd(), numIterations); // notice the .rdd() // Evaluate model on training examples and compute training error JavaRDD> valuesAndPred = parseddata .map(point -> new Tuple2(point.label(), model .predict(point.features()))); // important point here is the Tuple2 explicit creation. double MSE = valuesAndPred.mapToDouble( tuple -> Math.pow(tuple._1 - tuple._2, 2)).mean(); // you can compute the mean with this function, which is much easier System.out.println("training Mean Squared Error = " + String.valueOf(MSE)); } 

它远非完美,但希望能帮助您更好地理解如何把 MLlib 文档中的 Scala 示例改写成 Java。另外,下面附上 Spark 自带的一个 Java 示例 JavaHdfsLR(注意:它实现的是逻辑回归,而不是线性回归):

 package org.apache.spark.examples; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import java.io.Serializable; import java.util.Arrays; import java.util.Random; import java.util.regex.Pattern; /** * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ public final class JavaHdfsLR { private static final int D = 10; // Number of dimensions private static final Random rand = new Random(42); static void showWarning() { String warning = "WARN: This is a naive implementation of Logistic Regression " + "and is given as an example!\n" + "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + "for more conventional use."; System.err.println(warning); } static class DataPoint implements Serializable { DataPoint(double[] x, double y) { this.x = x; this.y = y; } double[] x; double y; } static class ParsePoint implements Function { private static final Pattern SPACE = Pattern.compile(" "); @Override public DataPoint call(String line) { String[] tok = SPACE.split(line); double y = Double.parseDouble(tok[0]); double[] x = new double[D]; for (int i = 0; i < D; i++) { x[i] = Double.parseDouble(tok[i + 1]); } return new DataPoint(x, y); } } static class VectorSum implements Function2 { @Override public double[] call(double[] a, double[] b) { double[] result = new double[D]; for (int j = 0; j < D; j++) { result[j] = a[j] + b[j]; } return result; } } static class ComputeGradient implements Function { private final double[] weights; 
ComputeGradient(double[] weights) { this.weights = weights; } @Override public double[] call(DataPoint p) { double[] gradient = new double[D]; for (int i = 0; i < D; i++) { double dot = dot(weights, px); gradient[i] = (1 / (1 + Math.exp(-py * dot)) - 1) * py * px[i]; } return gradient; } } public static double dot(double[] a, double[] b) { double x = 0; for (int i = 0; i < D; i++) { x += a[i] * b[i]; } return x; } public static void printWeights(double[] a) { System.out.println(Arrays.toString(a)); } public static void main(String[] args) { if (args.length < 2) { System.err.println("Usage: JavaHdfsLR  "); System.exit(1); } showWarning(); SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(args[0]); JavaRDD points = lines.map(new ParsePoint()).cache(); int ITERATIONS = Integer.parseInt(args[1]); // Initialize w to a random value double[] w = new double[D]; for (int i = 0; i < D; i++) { w[i] = 2 * rand.nextDouble() - 1; } System.out.print("Initial w: "); printWeights(w); for (int i = 1; i <= ITERATIONS; i++) { System.out.println("On iteration " + i); double[] gradient = points.map( new ComputeGradient(w) ).reduce(new VectorSum()); for (int j = 0; j < D; j++) { w[j] -= gradient[j]; } } System.out.print("Final w: "); printWeights(w); sc.stop(); } }