[英]Trying to understand what is happening in this Python Function
def closest_centroid(points, centroids):
"""returns an array containing the index to the nearest centroid for each point"""
distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
return np.argmin(distances, axis=0)
有人可以解释此功能的确切作用吗? 我目前得到的points
看起来像:
31998888119 0.94 34
23423423422 0.45 43
....
等等。 在此numpy
数组中,对于第一个条目, points[1]
将是长ID,而points[2]
为0.94
, points[3]
将为34
。
质心只是从此特定数组中的随机选择:
def initialize_centroids(points, k):
"""returns k centroids from the initial points"""
centroids = points.copy()
np.random.shuffle(centroids)
return centroids[:k]
现在,我想从忽略ID和centroids
的第一列的points
的值中获取欧几里得距离(再次忽略第一列)。 我从distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
不完全理解语法。 我们为什么要精确地对第三列进行求和,而又有一个新轴: np.newaxis
? 我还应该沿着哪个轴使np.argmin
工作?
这有助于考虑尺寸。 假设k=4
并且有10个点,那么points.shape = (10,3)
。
接下来, centroids = initialize_centroids(points, 4)
返回尺寸为(4,3)
的对象。
让我们从内部分解这一行:
distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
我们想从每个点减去每个质心。 由于points
和centroids
是二维的,因此每个points - centroid
都是二维的。 如果只有1个质心,那我们就可以了。 但是我们有4个质心! 因此,我们需要为每个质心执行points - centroids
重心。 因此,我们需要另一个维度来存储它。 因此增加了一个np.newaxis
。
我们对它求平方是因为它是一个距离,因此我们想将负值转换为正值(并且因为我们正在最小化欧几里得距离)。
我们不会在第三列进行汇总。 实际上,对于每个点,每个质心,我们都在求和点与质心之间的差。
np.argmin()
查找具有最小距离的质心。 因此,对于每个质心,对于每个点,都找到最小索引(因此用argmin
代替min
)。 该索引是分配给该点的质心。
这是一个例子:
points = np.array([
[ 1, 2, 4],
[ 1, 1, 3],
[ 1, 6, 2],
[ 6, 2, 3],
[ 7, 2, 3],
[ 1, 9, 6],
[ 6, 9, 1],
[ 3, 8, 6],
[ 10, 9, 6],
[ 0, 2, 0],
])
centroids = initialize_centroids(points, 4)
print(centroids)
array([[10, 9, 6],
[ 3, 8, 6],
[ 6, 2, 3],
[ 1, 1, 3]])
distances = (pts - centroids[:, np.newaxis])**2
print(distances)
array([[[ 81, 49, 4],
[ 81, 64, 9],
[ 81, 9, 16],
[ 16, 49, 9],
[ 9, 49, 9],
[ 81, 0, 0],
[ 16, 0, 25],
[ 49, 1, 0],
[ 0, 0, 0],
[100, 49, 36]],
[[ 4, 36, 4],
[ 4, 49, 9],
[ 4, 4, 16],
[ 9, 36, 9],
[ 16, 36, 9],
[ 4, 1, 0],
[ 9, 1, 25],
[ 0, 0, 0],
[ 49, 1, 0],
[ 9, 36, 36]],
[[ 25, 0, 1],
[ 25, 1, 0],
[ 25, 16, 1],
[ 0, 0, 0],
[ 1, 0, 0],
[ 25, 49, 9],
[ 0, 49, 4],
[ 9, 36, 9],
[ 16, 49, 9],
[ 36, 0, 9]],
[[ 0, 1, 1],
[ 0, 0, 0],
[ 0, 25, 1],
[ 25, 1, 0],
[ 36, 1, 0],
[ 0, 64, 9],
[ 25, 64, 4],
[ 4, 49, 9],
[ 81, 64, 9],
[ 1, 1, 9]]])
print(distances.sum(axis=2))
array([[134, 154, 106, 74, 67, 81, 41, 50, 0, 185],
[ 44, 62, 24, 54, 61, 5, 35, 0, 50, 81],
[ 26, 26, 42, 0, 1, 83, 53, 54, 74, 45],
[ 2, 0, 26, 26, 37, 73, 93, 62, 154, 11]])
# The minimum of the first 4 centroids is index 3. The minimum of the second 4 centroids is index 3 again.
print(np.argmin(distances.sum(axis=2), axis=0))
array([3, 3, 1, 2, 2, 1, 1, 1, 0, 3])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.