I'm new to Spark. I'm trying to train a decision tree model in Scala and use that model in Java to make predictions. How do I load that model in Java?
Use a saved Spark MLlib decision tree binary classification model to predict on new data
I'm using Spark 2.2.0 and Scala 2.11.8. I created and saved a decision tree binary classification model with the following code:
package...
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession
object DecisionTreeClassification {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder
      .master("local[*]")
      .appName("Decision Tree")
      .getOrCreate()
    // Load and parse the data file (LibSVM format).
    val data = MLUtils.loadLibSVMFile(sparkSession.sparkContext, "path/to/file/xyz.txt")
    // Split the data into training and test sets (20% held out for testing)
    val splits = data.randomSplit(Array(0.8, 0.2))
    val (trainingData, testData) = (splits(0), splits(1))
    // Train a DecisionTree model.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    val numClasses = 2
    val categoricalFeaturesInfo = Map[Int, Int]()
    val impurity = "gini"
    val maxDepth = 5
    val maxBins = 32
    val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)
    // Evaluate model on test instances and compute test error
    val labelAndPreds = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
    println(s"Test Error = $testErr")
    println(s"Learned classification tree model:\n ${model.toDebugString}")
    // Save the model, then load it back (save fails if the directory already exists)
    model.save(sparkSession.sparkContext, "target/tmp/myDecisionTreeClassificationModel")
    val sameModel = DecisionTreeModel.load(sparkSession.sparkContext, "target/tmp/myDecisionTreeClassificationModel")
    sparkSession.sparkContext.stop()
  }
}
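The test error printed above is the fraction of held-out points whose predicted label differs from the true label. A minimal Spark-free sketch of the same arithmetic on a local collection; the `(label, prediction)` pairs and the `TestErrorSketch` name are hypothetical, chosen only for illustration:

```scala
object TestErrorSketch {
  // Fraction of (label, prediction) pairs that disagree, mirroring
  // labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
  def testError(labelAndPreds: Seq[(Double, Double)]): Double = {
    require(labelAndPreds.nonEmpty, "need at least one test point")
    labelAndPreds.count { case (label, pred) => label != pred }.toDouble / labelAndPreds.size
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical pairs: exactly one of the four predictions is wrong
    val pairs = Seq((0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0))
    println(s"Test Error = ${testError(pairs)}") // prints "Test Error = 0.25"
  }
}
```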
Now I want to use this saved model to predict the labels (0 or 1) of new data. I'm new to Spark; could someone tell me how to do that?
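I found the answer to this question, so I'm sharing it in case anyone else is looking for an answer to a similar problem.
To predict on new data, just add a few lines before stopping the Spark session: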
val newData = MLUtils.loadLibSVMFile(sparkSession.sparkContext, "path/to/file/abc.txt")
val newDataPredictions = newData.map { point =>
  val newPrediction = model.predict(point.features)
  (point.label, newPrediction)
}
// collect() brings the results to the driver, so the output appears in the driver console
newDataPredictions.collect().foreach { case (_, prediction) => println(s"Predicted label: $prediction") }
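The lines above reuse the in-memory `model`. To use the saved model from a separate program, load it with `DecisionTreeModel.load`; from Java, Spark's own decision tree example calls the same method as `DecisionTreeModel.load(jsc.sc(), path)`, where `jsc` is a `JavaSparkContext`. A minimal Scala sketch, assuming the model was saved to the path used in the training code and that `abc.txt` is in LibSVM format. The object name, the feature count `numFeatures = 692`, and the hand-built vector are hypothetical; set `numFeatures` to the feature count of your training data so the vectors have the size the tree expects.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

object PredictWithSavedTree {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .master("local[*]")
      .appName("Decision Tree Prediction")
      .getOrCreate()
    val sc = spark.sparkContext

    // Load the model written by model.save(...) in the training program
    val model = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")

    // Hypothetical value: use the number of features the model was trained on
    val numFeatures = 692
    // LibSVM lines need a leading label; for unlabeled rows a placeholder such as 0 works,
    // since only the features are used for prediction
    val newData = MLUtils.loadLibSVMFile(sc, "path/to/file/abc.txt", numFeatures)

    // predict also accepts an RDD[Vector] and returns an RDD[Double]
    val predictions = model.predict(newData.map(_.features))
    predictions.collect().foreach(p => println(s"Predicted label: $p"))

    // A single hand-built sparse vector (hypothetical indices and values) is scored the same way
    val single = Vectors.sparse(numFeatures, Array(0, 5), Array(1.0, 3.5))
    println(s"Single prediction: ${model.predict(single)}")

    spark.stop()
  }
}
```

Passing `numFeatures` explicitly avoids relying on `loadLibSVMFile` inferring the vector size from the largest index present in the new file, which can differ from the training data.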
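Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original post's address. For any questions, contact: yoyou2525@163.com.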