[英]How to extract and map cluster indices from sklearn.cluster.KMeans?
我有一張數據地圖:
import seaborn as sns
import matplotlib.pyplot as plt
X = 101_by_99_float32_array
ax = sns.heatmap(X, square = True)
plt.show()
請注意,這些數據本質上是一個 3D 表面,我對聚類后X
中的索引位置感興趣。 我可以輕松地將 kmeans 算法應用於我的數據:
from sklearn.cluster import KMeans
# three clusters is arbitrary; just used for testing purposes
k_means = KMeans(init='k-means++', n_clusters=3, n_init=10).fit(X)
但我不確定如何以一種方式導航kmeans
,以識別上面地圖中的像素屬於哪個集群。 我想要做的是制作一個看起來像上面那個的地圖,但不是為 100x99 數組X
中的每個單元格繪制 z 值,我想繪制X
每個單元格的簇號。
我不知道這是可能的k均值算法的輸出,但我也嘗試從scikitlearn文件的方法在這里:
import numpy as np
k_means_labels = k_means.labels_
k_means_cluster_centers = k_means.cluster_centers_
k_means_labels_unique = np.unique(k_means_labels)
colors = ['#4EACC5', '#FF9C34', '#4E9A06']
plt.figure()
#plt.hold(True)
for k, col in zip(range(3), colors):
my_members = k_means_labels == k
cluster_center = k_means_cluster_centers[k]
plt.plot(X[my_members, 0], X[my_members, 1], 'w',
markerfacecolor=col, marker='.')
plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
markeredgecolor='k', markersize=6)
plt.title('KMeans')
plt.show()
但很明顯這不是訪問我想要的信息......
很明顯,我並沒有完全理解kmeans
輸出的每個組成部分代表什么,並且我試圖閱讀這里找到的問題的答案中的解釋。 但是,該答案中沒有任何內容明確說明聚類后原始數據的索引是否保留,這確實是我問題的核心。 如果這些信息通過一些矩陣乘法隱含在kmeans
,我真的可以使用一些幫助來提取它。
感謝您的時間和幫助!
編輯:
感謝@Nakor,對 kmeans 的解釋和重塑我的數據的建議。 kmeans
如何解釋我的數據現在更加清晰。 我不應該期望它捕獲每個樣本的索引,而是依靠reshape
來做到這一點。 reshape
將ravel
原始(101,99)矩陣為(9999,1)陣列,正如@Nakor指出的,適合於聚類的每個條目作為單獨的樣品。
只需使用數據的原始形狀將reshape
應用於kmeans.labels_
,我就得到了我正在尋找的結果:
Y = X.reshape(-1, 1) # shape data to cluster each individual entry
kmeans= KMeans(init='k-means++', n_clusters=3, n_init=10)
kmeans.fit(Y)
Z = kmeans.labels_
A = Z.reshape(101,99)
plt.figure()
ax = sns.heatmap(cu_map, square = True)
plt.figure()
ay = sns.heatmap(A, square = True)
您的問題是sklearn.cluster.KMeans
需要帶有[N_samples,N_features]
的二維矩陣。 但是,您提供了原始圖像,因此 sklearn 知道您有 101 個樣本,每個樣本有 99 個特征(圖像的每一行都是一個樣本,列是特征)。 結果,您在k_means.labels_
得到的是每一行的集群分配。
如果您想對每個條目進行聚類,則需要像這樣重塑您的數據,例如:
model = KMeans(init='k-means++', n_clusters=3, n_init=10)
model.fit(X.reshape(-1,1))
如果我檢查隨機生成的數據,我會得到:
In [1]: len(model.labels_)
Out[1]: 9999
我每個條目有一個標簽。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.