Java: KMeans Spark Join Cluster Center into Prediction Dataframe
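I know that KMeansModel.transform gives us a Dataset as output, with a features column and a prediction column that says which cluster each row was assigned to.
However, I also want this DataFrame to contain the cluster center that belongs to each row's features.
How can I do this in Java? I want my code to be pure Java (not Scala, even though Spark itself is Scala-based).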
+-----------+----------+
|   features|prediction|
+-----------+----------+
| [4.0,53.0]|         2|
| [5.0,63.0]|         3|
|[10.0,59.0]|         2|
|[13.0,49.0]|         0|
|[12.0,88.0]|         1|
|[12.0,88.0]|         1|
|[18.0,61.0]|         2|
+-----------+----------+
Expected result:
+-----------+----------+----------------+
|   features|prediction|   clusterCenter|
+-----------+----------+----------------+
| [4.0,53.0]|         1|   [20.15,64.95]|
| [5.0,63.0]|         0|  [43.91,146.04]|
|[10.0,59.0]|         2|       [20.4,68]|
|[13.0,49.0]|         0|  [43.91,146.04]|
|[12.0,88.0]|         1|   [20.15,64.95]|
|[12.0,88.0]|         2|       [20.4,68]|
|[18.0,61.0]|         3| [98.176,114.88]|
+-----------+----------+----------------+
Here is a test code snippet:
List<Row> dataset = Arrays.asList(
    RowFactory.create(Vectors.dense(4, 53)),
    RowFactory.create(Vectors.dense(5, 63)),
    RowFactory.create(Vectors.dense(10, 59)),
    RowFactory.create(Vectors.dense(13, 49)),
    RowFactory.create(Vectors.dense(12, 88)),
    RowFactory.create(Vectors.dense(12, 88)),
    RowFactory.create(Vectors.dense(18, 61))
);
StructType schema = new StructType(new StructField[]{
    new StructField("features", new VectorUDT(), false, Metadata.empty())
});
Dataset<Row> df = sc.createDataFrame(dataset, schema);
KMeans kMeans = new KMeans().setK(4).setMaxIter(10);
KMeansModel model = kMeans.fit(df);
Dataset<Row> predict = model.transform(df);
predict.show();
StructType centroidSchema = new StructType(new StructField[]{
    new StructField("x", DataTypes.StringType, false, Metadata.empty()),
    new StructField("y", DataTypes.StringType, false, Metadata.empty())
});
// parallelize expects a List, so wrap the Vector[] returned by clusterCenters()
Dataset<Row> centroid = sc.createDataFrame(jsc.parallelize(Arrays.asList(model.clusterCenters())).map(s -> {
    // Split the "[x,y]" text of each center into its two coordinates
    String[] row = s.toString().replace("[", "").replace("]", "").split(",");
    return RowFactory.create((Object[]) row);
}), centroidSchema);
centroid = centroid.withColumn("x", centroid.col("x").cast("double"));
centroid = centroid.withColumn("y", centroid.col("y").cast("double"));
centroid.show();
I solved this with the following approach:
...
Vector[] centroidVector = model.clusterCenters();
Dataset<Row> prediction = model.transform(dataset);
// Pair each center (as text) with its index; the index matches the value in the prediction column
List<Tuple2<String, Integer>> centroid = new ArrayList<>();
for (int i = 0; i < centroidVector.length; i++) {
    centroid.add(new Tuple2<>(Arrays.toString(centroidVector[i].toArray()), i));
}
JavaPairRDD<String, Integer> sorting = jsc.parallelize(centroid)
    .mapToPair((Tuple2<String, Integer> s) -> new Tuple2<>(s._1, s._2)).cache();
// Turn the (centroid, prediction) pairs into a small lookup DataFrame
Dataset<Row> df = sc.createDataset(sorting.collect(), Encoders.tuple(Encoders.STRING(), Encoders.INT()))
    .toDF("centroid", "prediction").cache();
// Join on prediction, then drop the duplicate prediction column
Dataset<Row> result = prediction
    .join(df, df.col("prediction").equalTo(prediction.col("prediction")))
    .drop(prediction.col("prediction"));
result = result.withColumn("features", result.col("features").cast("String"));
result = result.withColumn("centroid", result.col("centroid").cast("String"));
result.show();
Result on the Ruspini dataset:
+-----+-----+-------------+--------------------+----------+
|   f1|   f2|     features|            centroid|prediction|
+-----+-----+-------------+--------------------+----------+
|  4.0| 53.0|   [4.0,53.0]|[20.1500000000000...|         0|
|  5.0| 63.0|   [5.0,63.0]|[20.1500000000000...|         0|
| 10.0| 59.0|  [10.0,59.0]|[20.1500000000000...|         0|
| 32.0| 61.0|  [32.0,61.0]|[20.1500000000000...|         0|
| 28.0|147.0| [28.0,147.0]|[43.9130434782608...|         1|
| 32.0|149.0| [32.0,149.0]|[43.9130434782608...|         1|
| 41.0|150.0| [41.0,150.0]|[43.9130434782608...|         1|
| 52.0|152.0| [52.0,152.0]|[43.9130434782608...|         1|
| 86.0|132.0| [86.0,132.0]|[98.1764705882352...|         3|
| 85.0|115.0| [85.0,115.0]|[98.1764705882352...|         3|
| 85.0| 96.0|  [85.0,96.0]|[98.1764705882352...|         3|
| 78.0| 94.0|  [78.0,94.0]|[98.1764705882352...|         3|
| 70.0|  4.0|   [70.0,4.0]|[68.9333333333333...|         2|
| 77.0| 12.0|  [77.0,12.0]|[68.9333333333333...|         2|
| 83.0| 21.0|  [83.0,21.0]|[68.9333333333333...|         2|
| 61.0| 15.0|  [61.0,15.0]|[68.9333333333333...|         2|
+-----+-----+-------------+--------------------+----------+
If anyone has another approach, I would like to hear it.
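One possible alternative view of the problem: the value in the prediction column is the index of the matching entry in model.clusterCenters(), so attaching a center to a row amounts to an array lookup. The sketch below shows only that lookup in plain Java, without Spark. The class CenterLookup, the type PointWithCenter, the method attachCenters, and the sample numbers in main are all hypothetical names and made-up data for illustration, not part of the Spark API or of the results above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CenterLookup {

    /** One output row: the features, the cluster id, and that cluster's center. */
    public static final class PointWithCenter {
        public final double[] features;
        public final int prediction;
        public final double[] clusterCenter;

        PointWithCenter(double[] features, int prediction, double[] clusterCenter) {
            this.features = features;
            this.prediction = prediction;
            this.clusterCenter = clusterCenter;
        }

        @Override
        public String toString() {
            return Arrays.toString(features) + " | " + prediction + " | "
                    + Arrays.toString(clusterCenter);
        }
    }

    /**
     * Attaches a center to every point by using the prediction as an index
     * into the centers array (assumption: centers are ordered by cluster id).
     */
    public static List<PointWithCenter> attachCenters(
            double[][] features, int[] predictions, double[][] centers) {
        if (features.length != predictions.length) {
            throw new IllegalArgumentException("features and predictions differ in length");
        }
        List<PointWithCenter> rows = new ArrayList<>();
        for (int i = 0; i < features.length; i++) {
            int cluster = predictions[i];
            if (cluster < 0 || cluster >= centers.length) {
                throw new IllegalArgumentException("no center for cluster " + cluster);
            }
            rows.add(new PointWithCenter(features[i], cluster, centers[cluster]));
        }
        return rows;
    }

    public static void main(String[] args) {
        // Made-up centers and predictions, only to exercise the lookup
        double[][] centers = {{10.0, 50.0}, {12.0, 88.0}};
        double[][] features = {{4.0, 53.0}, {12.0, 88.0}, {13.0, 49.0}};
        int[] predictions = {0, 1, 0};
        for (PointWithCenter row : attachCenters(features, predictions, centers)) {
            System.out.println(row);
        }
    }
}
```

In Spark, the same lookup could presumably be applied to the prediction column through a user-defined function instead of a join; that part is not shown here and would need checking against the Spark version in use.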