[英]How to extract and map cluster indices from sklearn.cluster.KMeans?
我有一张数据地图:
import seaborn as sns
import matplotlib.pyplot as plt
X = 101_by_99_float32_array
ax = sns.heatmap(X, square = True)
plt.show()
请注意,这些数据本质上是一个 3D 表面,我对聚类后X
中的索引位置感兴趣。 我可以轻松地将 kmeans 算法应用于我的数据:
from sklearn.cluster import KMeans
# three clusters is arbitrary; just used for testing purposes
k_means = KMeans(init='k-means++', n_clusters=3, n_init=10).fit(X)
但我不确定如何以一种方式导航kmeans
,以识别上面地图中的像素属于哪个集群。 我想要做的是制作一个看起来像上面那个的地图,但不是为 100x99 数组X
中的每个单元格绘制 z 值,我想绘制X
每个单元格的簇号。
我不知道这是可能的k均值算法的输出,但我也尝试从scikitlearn文件的方法在这里:
import numpy as np
k_means_labels = k_means.labels_
k_means_cluster_centers = k_means.cluster_centers_
k_means_labels_unique = np.unique(k_means_labels)
colors = ['#4EACC5', '#FF9C34', '#4E9A06']
plt.figure()
#plt.hold(True)
for k, col in zip(range(3), colors):
my_members = k_means_labels == k
cluster_center = k_means_cluster_centers[k]
plt.plot(X[my_members, 0], X[my_members, 1], 'w',
markerfacecolor=col, marker='.')
plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
markeredgecolor='k', markersize=6)
plt.title('KMeans')
plt.show()
但很明显这不是访问我想要的信息......
很明显,我并没有完全理解kmeans
输出的每个组成部分代表什么,并且我试图阅读这里找到的问题的答案中的解释。 但是,该答案中没有任何内容明确说明聚类后原始数据的索引是否保留,这确实是我问题的核心。 如果这些信息通过一些矩阵乘法隐含在kmeans
,我真的可以使用一些帮助来提取它。
感谢您的时间和帮助!
编辑:
感谢@Nakor,对 kmeans 的解释和重塑我的数据的建议。 kmeans
如何解释我的数据现在更加清晰。 我不应该期望它捕获每个样本的索引,而是依靠reshape
来做到这一点。 reshape
将ravel
原始(101,99)矩阵为(9999,1)阵列,正如@Nakor指出的,适合于聚类的每个条目作为单独的样品。
只需使用数据的原始形状将reshape
应用于kmeans.labels_
,我就得到了我正在寻找的结果:
Y = X.reshape(-1, 1) # shape data to cluster each individual entry
kmeans= KMeans(init='k-means++', n_clusters=3, n_init=10)
kmeans.fit(Y)
Z = kmeans.labels_
A = Z.reshape(101,99)
plt.figure()
ax = sns.heatmap(cu_map, square = True)
plt.figure()
ay = sns.heatmap(A, square = True)
您的问题是sklearn.cluster.KMeans
需要带有[N_samples,N_features]
的二维矩阵。 但是,您提供了原始图像,因此 sklearn 知道您有 101 个样本,每个样本有 99 个特征(图像的每一行都是一个样本,列是特征)。 结果,您在k_means.labels_
得到的是每一行的集群分配。
如果您想对每个条目进行聚类,则需要像这样重塑您的数据,例如:
model = KMeans(init='k-means++', n_clusters=3, n_init=10)
model.fit(X.reshape(-1,1))
如果我检查随机生成的数据,我会得到:
In [1]: len(model.labels_)
Out[1]: 9999
我每个条目有一个标签。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.