Spark Multiclass Classification Example
Do you know where I can find examples of multiclass classification in Spark? I spent a lot of time searching in books and on the web, and so far I only know that it is possible as of the latest version, according to the documentation.
ML
(Recommended in Spark 2.0+)
We'll use the same data as in the MLlib section below. There are two basic options. If the Estimator supports multiclass classification out of the box (for example random forest), you can use it directly:
val trainRawDf = trainRaw.toDF
import org.apache.spark.ml.feature.{Tokenizer, CountVectorizer, StringIndexer}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
val transformers = Array(
new StringIndexer().setInputCol("group").setOutputCol("label"),
new Tokenizer().setInputCol("text").setOutputCol("tokens"),
new CountVectorizer().setInputCol("tokens").setOutputCol("features")
)
val rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
val model = new Pipeline().setStages(transformers :+ rf).fit(trainRawDf)
model.transform(trainRawDf)
If the model supports only binary classification (logistic regression) and extends org.apache.spark.ml.classification.Classifier, you can use the one-vs-rest strategy:
import org.apache.spark.ml.classification.OneVsRest
import org.apache.spark.ml.classification.LogisticRegression
val lr = new LogisticRegression()
.setLabelCol("label")
.setFeaturesCol("features")
val ovr = new OneVsRest().setClassifier(lr)
val ovrModel = new Pipeline().setStages(transformers :+ ovr).fit(trainRawDf)
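Either fitted pipeline can be scored the same way. The following is only a sketch, assuming the transformed output carries the default "prediction" column next to "label"; the helper name f1Score is made up for illustration and is not part of the original answer.

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame

// Hypothetical helper: weighted F1 over a DataFrame that has
// "label" and "prediction" columns, e.g. the result of
// model.transform(...) or ovrModel.transform(...).
def f1Score(predictions: DataFrame): Double =
  new MulticlassClassificationEvaluator()
    .setLabelCol("label")
    .setPredictionCol("prediction")
    .setMetricName("f1")
    .evaluate(predictions)
```

For example, f1Score(ovrModel.transform(trainRawDf)) would score the one-vs-rest pipeline on the training data, which is only a sanity check and not a real evaluation.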
MLlib
According to the official documentation at this moment (MLlib 1.6.0), the following methods support multiclass classification:
At least some of the examples use multiclass classification:
The general framework, ignoring method-specific arguments, is pretty much the same as for all the other methods in MLlib. You have to pre-process your input to create either a data frame with columns representing label and features:
root
|-- label: double (nullable = true)
|-- features: vector (nullable = true)
or RDD[LabeledPoint].
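For reference, a single LabeledPoint is just a Double label paired with a feature vector. A minimal sketch with made-up values, using the RDD-based mllib classes:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Class 1.0 with a dense three-element feature vector (values are illustrative)
val densePoint = LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))

// The same features in sparse form: size 3, non-zero entries at indices 1 and 2
val sparsePoint = LabeledPoint(1.0, Vectors.sparse(3, Array(1, 2), Array(2.0, 1.0)))
```

Wrapping a collection of such points with sc.parallelize gives an RDD[LabeledPoint].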
Spark provides a broad range of useful tools designed to facilitate this process, including Feature Extractors, Feature Transformers and pipelines.
You'll find a rather naive example of using Random Forest below.
First let's import the required packages and create dummy data:
import sqlContext.implicits._
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.Row
import org.apache.spark.rdd.RDD
case class LabeledRecord(group: String, text: String)
val trainRaw = sc.parallelize(
LabeledRecord("foo", "foo v a y b foo") ::
LabeledRecord("bar", "x bar y bar v") ::
LabeledRecord("bar", "x a y bar z") ::
LabeledRecord("foobar", "foo v b bar z") ::
LabeledRecord("foo", "foo x") ::
LabeledRecord("foobar", "z y x foo a b bar v") ::
Nil
)
Now let's define the required transformers and process the train Dataset:
// Tokenizer to process text fields
val tokenizer = new Tokenizer()
.setInputCol("text")
.setOutputCol("words")
// HashingTF to convert tokens to the feature vector
val hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("features")
.setNumFeatures(10)
// Indexer to convert String labels to Double
val indexer = new StringIndexer()
.setInputCol("group")
.setOutputCol("label")
.fit(trainRaw.toDF)
def transfom(rdd: RDD[LabeledRecord]) = {
val tokenized = tokenizer.transform(rdd.toDF)
val hashed = hashingTF.transform(tokenized)
val indexed = indexer.transform(hashed)
indexed
.select($"label", $"features")
.map{case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)}
}
val train: RDD[LabeledPoint] = transfom(trainRaw)
Please note that indexer is "fitted" on the train data. It simply means that categorical values used as the labels are converted to doubles. To use the classifier on new data you have to transform it first using this indexer.
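The fitted indexer also lets you go the other way, from a predicted double back to the original string. A small sketch, assuming indexer is the StringIndexerModel fitted above; decodeLabel is a hypothetical helper name, not something from the original answer:

```scala
import org.apache.spark.ml.feature.StringIndexerModel

// Hypothetical helper: map a numeric prediction back to its string label.
// `labels` is ordered by index, so position i holds the string that was
// encoded as i.toDouble (the most frequent label gets index 0).
def decodeLabel(indexer: StringIndexerModel, prediction: Double): String =
  indexer.labels(prediction.toInt)
```

With the indexer above, decodeLabel(indexer, 0.0) would return whichever of "foo", "bar" or "foobar" was assigned index 0.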
Next we can train the RF model:
val numClasses = 3
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 10
val featureSubsetStrategy = "auto"
val impurity = "gini"
val maxDepth = 4
val maxBins = 16
val model = RandomForest.trainClassifier(
train, numClasses, categoricalFeaturesInfo,
numTrees, featureSubsetStrategy, impurity,
maxDepth, maxBins
)
and finally test it:
val testRaw = sc.parallelize(
LabeledRecord("foo", "foo foo z z z") ::
LabeledRecord("bar", "z bar y y v") ::
LabeledRecord("bar", "a a bar a z") ::
LabeledRecord("foobar", "foo v b bar z") ::
LabeledRecord("foobar", "a foo a bar") ::
Nil
)
val test: RDD[LabeledPoint] = transfom(testRaw)
val predsAndLabs = test.map(lp => (model.predict(lp.features), lp.label))
val metrics = new MulticlassMetrics(predsAndLabs)
metrics.precision
metrics.recall
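With more than two classes, the per-class numbers are often more telling than the overall ones. A sketch of how they could be collected from the same (prediction, label) pairs; perClassMetrics is a name made up for this example:

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.rdd.RDD

// Hypothetical helper: (precision, recall) for every class label
// found in an RDD of (prediction, label) pairs.
def perClassMetrics(predsAndLabs: RDD[(Double, Double)]): Map[Double, (Double, Double)] = {
  val m = new MulticlassMetrics(predsAndLabs)
  m.labels.map(l => l -> (m.precision(l), m.recall(l))).toMap
}
```

Calling perClassMetrics(predsAndLabs) on the pairs above gives one entry per indexed class, and metrics.confusionMatrix shows where the misclassifications fall.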
Are you using Spark 1.6 rather than Spark 2.1? I think the problem is that in Spark 2.1 the transform method returns a Dataset, which can be implicitly converted to a typed RDD, whereas prior to that it returns a data frame or row.
As a diagnostic, try specifying the return type of the transform function as RDD[LabeledPoint] and see if you get the same error.
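The original error message is not shown, so the following is only a guess at what a Spark 2.x version of the conversion step might look like. It assumes the "features" column holds the newer org.apache.spark.ml.linalg vectors and converts them for the RDD-based API; toLabeledPoints is a hypothetical name.

```scala
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

// Hypothetical Spark 2.x variant with an explicit return type.
// Going through .rdd avoids needing an Encoder for LabeledPoint,
// and Vectors.fromML converts ml vectors to mllib vectors.
def toLabeledPoints(indexed: DataFrame): RDD[LabeledPoint] =
  indexed.select("label", "features").rdd.map {
    case Row(label: Double, features: MLVector) =>
      LabeledPoint(label, Vectors.fromML(features))
  }
```

If the compiler still complains with the return type spelled out, the message should at least point at the exact step that fails.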