
Java: KMeans Spark — Join Cluster Centers into the Prediction Dataframe

I know that KMeansModel's transform gives its output as a Dataset whose columns are _c0, _c1, features, and prediction, where the prediction column states which cluster each row was assigned to.

However, I'd also like this dataframe to include the cluster center corresponding to each row's features.

How can I do this in Java? I'd like my code to be pure Java (not Scala, even though Spark is built on Scala).

+-----------+----------+
|   features|prediction|
+-----------+----------+
| [4.0,53.0]|         2|
| [5.0,63.0]|         3|
|[10.0,59.0]|         2|
|[13.0,49.0]|         0|
|[12.0,88.0]|         1|
|[12.0,88.0]|         1|
|[18.0,61.0]|         2|
+-----------+----------+

Expected result:

+-----------+----------+---------------+
|   features|prediction|  clusterCenter|
+-----------+----------+---------------+
| [4.0,53.0]|         1|  [20.15,64.95]|
| [5.0,63.0]|         0| [43.91,146.04]|
|[10.0,59.0]|         2|      [20.4,68]|
|[13.0,49.0]|         0| [43.91,146.04]|
|[12.0,88.0]|         1|  [20.15,64.95]|
|[12.0,88.0]|         2|      [20.4,68]|
|[18.0,61.0]|         3|[98.176,114.88]|
+-----------+----------+---------------+

Here is a test code snippet:

List<Row> dataset = Arrays.asList(
     RowFactory.create(Vectors.dense(4,53)),
     RowFactory.create(Vectors.dense(5,63)),
     RowFactory.create(Vectors.dense(10,59)),
     RowFactory.create(Vectors.dense(13,49)),
     RowFactory.create(Vectors.dense(12,88)),
     RowFactory.create(Vectors.dense(12,88)),
     RowFactory.create(Vectors.dense(18,61))
);

StructType schema = new StructType(new StructField[]{
     new StructField("features", new VectorUDT(), false, Metadata.empty()),
});

Dataset<Row> df = sc.createDataFrame(dataset, schema);
KMeans kMeans = new KMeans().setK(4).setMaxIter(10);
KMeansModel model = kMeans.fit(df);
Dataset<Row> predict = model.transform(df);
predict.show();

StructType centroidSchema = new StructType(new StructField[]{
    new StructField("x", DataTypes.StringType, false, Metadata.empty()),
    new StructField("y", DataTypes.StringType, false, Metadata.empty())
});

Dataset<Row> centroid = sc.createDataFrame(jsc.parallelize(model.clusterCenters()).map(s -> {
     String[] row = s.toString().replace("[","").replace("]","").split(",");
     return RowFactory.create((Object[]) row);
}), centroidSchema);
centroid = centroid.withColumn("x", centroid.col("x").cast("Double"));
centroid = centroid.withColumn("y", centroid.col("y").cast("Double"));
centroid.show();

I've resolved this question with the approach below:

...
Vector[] centroidVector = model.clusterCenters();
Dataset<Row> prediction = model.transform(dataset);

// Pair each cluster center (as a string) with its cluster index
List<Tuple2<String, Integer>> centroid = new ArrayList<>();
for (int i = 0; i < centroidVector.length; i++){
    centroid.add(new Tuple2<>(Arrays.toString(centroidVector[i].toArray()), i));
}

// Turn the (center, index) pairs into a two-column DataFrame to join on
Dataset<Row> df = sc.createDataset(centroid, Encoders.tuple(Encoders.STRING(), Encoders.INT()))
        .toDF("centroid", "prediction").cache();

Dataset<Row> result = prediction.join(df, df.col("prediction").equalTo(prediction.col("prediction"))).drop(prediction.col("prediction"));
result=result.withColumn("features", result.col("features").cast("String"));
result=result.withColumn("centroid", result.col("centroid").cast("String"));
result.show();

Result with the ruspini dataset:

+-----+-----+-------------+--------------------+----------+
|   f1|   f2|     features|            centroid|prediction|
+-----+-----+-------------+--------------------+----------+
|  4.0| 53.0|   [4.0,53.0]|[20.1500000000000...|         0|
|  5.0| 63.0|   [5.0,63.0]|[20.1500000000000...|         0|
| 10.0| 59.0|  [10.0,59.0]|[20.1500000000000...|         0|
| 32.0| 61.0|  [32.0,61.0]|[20.1500000000000...|         0|
| 28.0|147.0| [28.0,147.0]|[43.9130434782608...|         1|
| 32.0|149.0| [32.0,149.0]|[43.9130434782608...|         1|
| 41.0|150.0| [41.0,150.0]|[43.9130434782608...|         1|
| 52.0|152.0| [52.0,152.0]|[43.9130434782608...|         1|
| 86.0|132.0| [86.0,132.0]|[98.1764705882352...|         3|
| 85.0|115.0| [85.0,115.0]|[98.1764705882352...|         3|
| 85.0| 96.0|  [85.0,96.0]|[98.1764705882352...|         3|
| 78.0| 94.0|  [78.0,94.0]|[98.1764705882352...|         3|
| 70.0|  4.0|   [70.0,4.0]|[68.9333333333333...|         2|
| 77.0| 12.0|  [77.0,12.0]|[68.9333333333333...|         2|
| 83.0| 21.0|  [83.0,21.0]|[68.9333333333333...|         2|
| 61.0| 15.0|  [61.0,15.0]|[68.9333333333333...|         2|
+-----+-----+-------------+--------------------+----------+

If anyone has another approach, I'd like to hear it.
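One join-free alternative: since prediction is simply an index into model.clusterCenters(), the center for each row can be looked up directly (for example inside a map over the prediction DataFrame, or via a UDF), avoiding the shuffle a join triggers. A minimal plain-Java sketch of that lookup logic, with hypothetical centers and predictions standing in for the model's output:

```java
import java.util.Arrays;

public class CenterLookup {
    // For each row, use its predicted cluster index to pick the matching center.
    static double[][] attachCenters(double[][] centers, int[] predictions) {
        double[][] out = new double[predictions.length][];
        for (int i = 0; i < predictions.length; i++) {
            // prediction is the index into clusterCenters()
            out[i] = centers[predictions[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical stand-ins for model.clusterCenters() and the prediction column
        double[][] centers = {{20.15, 64.95}, {43.91, 146.04}};
        int[] predictions = {0, 1, 0};
        for (double[] c : attachCenters(centers, predictions)) {
            System.out.println(Arrays.toString(c)); // e.g. [20.15, 64.95]
        }
    }
}
```

In Spark the same idea can be expressed as a UDF over the prediction column that returns the looked-up center, so no separate centroid DataFrame or join is needed.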
