簡體   English   中英

獲取離質心最近的點,scikit-learn?

[英]Get nearest point to centroid, scikit-learn?

我正在使用 K-means 來解決聚類問題。 我試圖找到最接近質心的數據點,我相信它被稱為 medoid。

有沒有辦法在 scikit-learn 中做到這一點?

這不是medoid,但您可以嘗試以下方法:

>>> import numpy as np
>>> from sklearn.cluster import KMeans
>>> from sklearn.metrics import pairwise_distances_argmin_min
>>> X = np.random.randn(10, 4)
>>> km = KMeans(n_clusters=2).fit(X)
>>> closest, _ = pairwise_distances_argmin_min(km.cluster_centers_, X)
>>> closest
array([0, 8])

數組closest包含X中最接近每個質心的點的索引。 所以X[0]X離質心0最近的點, X[8]是離質心1最近的點。

我嘗試了上面的答案,但它給了我重復的結果。 無論聚類結果如何,以上都會找到最近的數據點。 因此它可以返回同一個集群的副本。

如果您想在中心指示的同一集群中找到最接近的數據,請嘗試此操作。

該解決方案給出的數據點來自所有不同的集群,並且返回的數據點數量與集群數量相同。

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

# assume the total number of data is 100
all_data = [ i for i in range(100) ]
tf_matrix = numpy.random.random((100, 100))

# set your own number of clusters
num_clusters = 2

m_km = KMeans(n_clusters=num_clusters)  
m_km.fit(tf_matrix)
m_clusters = m_km.labels_.tolist()

centers = np.array(m_km.cluster_centers_)

closest_data = []
for i in range(num_clusters):
    center_vec = centers[i]
    data_idx_within_i_cluster = [ idx for idx, clu_num in enumerate(m_clusters) if clu_num == i ]

    one_cluster_tf_matrix = np.zeros( (  len(data_idx_within_i_cluster) , centers.shape[1] ) )
    for row_num, data_idx in enumerate(data_idx_within_i_cluster):
        one_row = tf_matrix[data_idx]
        one_cluster_tf_matrix[row_num] = one_row

    closest, _ = pairwise_distances_argmin_min(center_vec, one_cluster_tf_matrix)
    closest_idx_in_one_cluster_tf_matrix = closest[0]
    closest_data_row_num = data_idx_within_i_cluster[closest_idx_in_one_cluster_tf_matrix]
    data_id = all_data[closest_data_row_num]

    closest_data.append(data_id)

closest_data = list(set(closest_data))

assert len(closest_data) == num_clusters

您想要實現的基本上是矢量量化,但是“反向”。 Scipy有一個非常優化的功能,比提到的其他方法快得多。 輸出與pairwise_distances_argmin_min()相同。

    from scipy.cluster.vq import vq

    # centroids: N-dimensional array with your centroids
    # points:    N-dimensional array with your data points

    closest, distances = vq(centroids, points)

當您使用非常大的數組執行它時,就會有很大的不同,我使用 100000+ 個點和 65000+ 個質心的數組來執行它,這種方法比scikit 中的pairwise_distances_argmin_min ()快 4 倍,如下所示:

     start_time = time.time()
     cl2, dst2 = vq(centroids, points)
     print("--- %s seconds ---" % (time.time() - start_time))
     --- 32.13545227050781 seconds ---

     start_time = time.time()
     cl2, dst2 = pairwise_distances_argmin_min(centroids, points)
     print("--- %s seconds ---" % (time.time() - start_time))
     --- 131.21064710617065 seconds ---

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM