Multiclass classification with Gradient Boosted Trees in Spark: only binary classification is supported

I am trying to run multiclass classification with Gradient Boosted Trees in Spark MLlib, but it fails with the error "Only binary classification is supported for boosting." The dependent variable has 8 levels, and the data has 276 columns and 7000 instances.

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.feature.ChiSqSelector

import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

val data = sc.textFile("data/mllib/train.csv")

val parsedData = data.map { line =>
    val parts = line.split(',').map(_.toDouble)
    LabeledPoint(parts(0), Vectors.dense(parts.tail))
}


// Discretize each feature into integer bins of width 20 (floor(x / 20)), since ChiSqSelector requires categorical features
// Even though the features are still doubles, ChiSqSelector treats each distinct value as a category
val discretizedData = parsedData.map { lp =>
  LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 20).floor } ) )
}

// Create a ChiSqSelector that will select the top 200 features
val selector = new ChiSqSelector(200)

// Create ChiSqSelector model (selecting features)
val transformer = selector.fit(discretizedData)

// Keep only the top 200 selected features in each feature vector
val filteredData = discretizedData.map { lp => 
  LabeledPoint(lp.label, transformer.transform(lp.features)) 
}

// Split the data 70/30 into training and test sets
val splits = filteredData.randomSplit(Array(0.7, 0.3), seed = 11L)
val training = splits(0).cache()
val test = splits(1)


// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 20 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 8
boostingStrategy.treeStrategy.maxDepth = 10
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(training, boostingStrategy)

Error after running the model:

java.lang.IllegalArgumentException: requirement failed: Only binary classification is supported for boosting.
        at scala.Predef$.require(Predef.scala:233)
        at org.apache.spark.mllib.tree.configuration.BoostingStrategy.assertValid(BoostingStrategy.scala:60)
        at org.apache.spark.mllib.tree.GradientBoostedTrees$.org$apache$spark$mllib$tree$GradientBoostedTrees$$boost(GradientBoostedTrees.scala:173)
        at org.apache.spark.mllib.tree.GradientBoostedTrees.run(GradientBoostedTrees.scala:71)
        at org.apache.spark.mllib.tree.GradientBoostedTrees$.train(GradientBoostedTrees.scala:143)

Is there any other way this can be done?

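Unfortunately, at the time of writing, only logistic regression, decision trees, random forests and naive Bayes support multiclass classification in Spark MLlib (spark.mllib / spark.ml); gradient-boosted trees are limited to binary classification, which is exactly what the requirement in the stack trace enforces.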
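Since random forests are the closest tree-ensemble substitute for boosted trees, here is a minimal sketch of that route. It assumes the RDD-based spark.mllib API used in the question and labels already encoded as 0.0 through 7.0, which trainClassifier requires for numClasses = 8. The object name MulticlassForest and the tree settings (20 trees, Gini impurity, 32 bins) are hypothetical and untuned; only maxDepth = 10 is carried over from the question.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD

// Hypothetical helper: trains a multiclass random forest instead of boosted trees.
object MulticlassForest {
  // Assumes labels are 0.0, 1.0, ..., numClasses - 1.
  def train(training: RDD[LabeledPoint], numClasses: Int = 8): RandomForestModel = {
    val categoricalFeaturesInfo = Map[Int, Int]() // treat every feature as continuous
    val numTrees = 20                             // illustrative, not tuned
    val featureSubsetStrategy = "auto"            // let Spark choose features per split
    val impurity = "gini"
    val maxDepth = 10                             // same depth as the GBT attempt
    val maxBins = 32
    val seed = 11
    RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
  }

  // Fraction of points whose predicted class equals the true label.
  def accuracy(model: RandomForestModel, data: RDD[LabeledPoint]): Double = {
    val total = data.count()
    if (total == 0) 0.0
    else data.filter(lp => model.predict(lp.features) == lp.label).count().toDouble / total
  }
}
```

With the question's splits this would be used as val model = MulticlassForest.train(training) followed by MulticlassForest.accuracy(model, test).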

So I'd suggest switching to one of those classification methods.
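Logistic regression from the same list is another option. A minimal sketch, assuming spark.mllib's LogisticRegressionWithLBFGS (which fits a multinomial model when numClasses is greater than 2) and labels encoded as 0.0 through 7.0; the object name MulticlassLogistic is hypothetical.

```scala
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical helper: multinomial logistic regression as another multiclass option.
object MulticlassLogistic {
  // Assumes labels are 0.0, 1.0, ..., numClasses - 1.
  def train(training: RDD[LabeledPoint], numClasses: Int = 8): LogisticRegressionModel =
    new LogisticRegressionWithLBFGS()
      .setNumClasses(numClasses) // more than 2 classes makes MLlib fit a multinomial model
      .run(training)
}
```

Unlike the tree-based methods, logistic regression is sensitive to feature scaling, so the discretized ChiSqSelector output may need standardizing before this performs well.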

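Disclaimer: technical posts on this site are licensed under CC BY-SA 4.0. If you republish, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.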

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM