繁体   English   中英

试图了解此Python函数中正在发生的事情

[英]Trying to understand what is happening in this Python Function

def closest_centroid(points, centroids):
    """returns an array containing the index to the nearest centroid for each point"""
    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
    return np.argmin(distances, axis=0)

有人可以解释此功能的确切作用吗? 我目前得到的points看起来像:

31998888119     0.94     34
23423423422     0.45     43
....

等等。 在此numpy数组中,对于第一个条目, points[1]将是长ID,而points[2]0.94points[3]将为34

质心只是从此特定数组中的随机选择:

def initialize_centroids(points, k):
    """returns k centroids from the initial points"""
    centroids = points.copy()
    np.random.shuffle(centroids)
    return centroids[:k] 

现在,我想从忽略ID和centroids的第一列的points的值中获取欧几里得距离(再次忽略第一列)。 我从distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))不完全理解语法。 我们为什么要精确地对第三列进行求和,而又有一个新轴: np.newaxis 我还应该沿着哪个轴使np.argmin工作?

这有助于考虑尺寸。 假设k=4并且有10个点,那么points.shape = (10,3)

接下来, centroids = initialize_centroids(points, 4)返回尺寸为(4,3)的对象。

让我们从内部分解这一行:

distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))

  1. 我们想从每个点减去每个质心。 由于pointscentroids是二维的,因此每个points - centroid都是二维的。 如果只有1个质心,那我们就可以了。 但是我们有4个质心! 因此,我们需要为每个质心执行points - centroids重心。 因此,我们需要另一个维度来存储它。 因此增加了一个np.newaxis

  2. 我们对它求平方是因为它是一个距离,因此我们想将负值转换为正值(并且因为我们正在最小化欧几里得距离)。

  3. 我们不会在第三列进行汇总。 实际上,对于每个点,每个质心,我们都在求和点与质心之间的差。

  4. np.argmin()查找具有最小距离的质心。 因此,对于每个质心,对于每个点,都找到最小索引(因此用argmin代替min )。 该索引是分配给该点的质心。

这是一个例子:

points = np.array([
[   1, 2, 4],
[   1, 1, 3],
[   1, 6, 2],
[   6, 2, 3],
[   7, 2, 3],
[   1, 9, 6],
[   6, 9, 1],
[   3, 8, 6],
[   10, 9, 6],
[   0, 2, 0],
])

centroids = initialize_centroids(points, 4)

print(centroids)
array([[10,  9,  6],
   [ 3,  8,  6],
   [ 6,  2,  3],
   [ 1,  1,  3]])

distances = (pts - centroids[:, np.newaxis])**2

print(distances)
array([[[ 81,  49,   4],
    [ 81,  64,   9],
    [ 81,   9,  16],
    [ 16,  49,   9],
    [  9,  49,   9],
    [ 81,   0,   0],
    [ 16,   0,  25],
    [ 49,   1,   0],
    [  0,   0,   0],
    [100,  49,  36]],

   [[  4,  36,   4],
    [  4,  49,   9],
    [  4,   4,  16],
    [  9,  36,   9],
    [ 16,  36,   9],
    [  4,   1,   0],
    [  9,   1,  25],
    [  0,   0,   0],
    [ 49,   1,   0],
    [  9,  36,  36]],

   [[ 25,   0,   1],
    [ 25,   1,   0],
    [ 25,  16,   1],
    [  0,   0,   0],
    [  1,   0,   0],
    [ 25,  49,   9],
    [  0,  49,   4],
    [  9,  36,   9],
    [ 16,  49,   9],
    [ 36,   0,   9]],

   [[  0,   1,   1],
    [  0,   0,   0],
    [  0,  25,   1],
    [ 25,   1,   0],
    [ 36,   1,   0],
    [  0,  64,   9],
    [ 25,  64,   4],
    [  4,  49,   9],
    [ 81,  64,   9],
    [  1,   1,   9]]])

print(distances.sum(axis=2))
array([[134, 154, 106,  74,  67,  81,  41,  50,   0, 185],
   [ 44,  62,  24,  54,  61,   5,  35,   0,  50,  81],
   [ 26,  26,  42,   0,   1,  83,  53,  54,  74,  45],
   [  2,   0,  26,  26,  37,  73,  93,  62, 154,  11]])

# The minimum of the first 4 centroids is index 3. The minimum of the second 4 centroids is index 3 again.

print(np.argmin(distances.sum(axis=2), axis=0))
array([3, 3, 1, 2, 2, 1, 1, 1, 0, 3])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM