[英]Fast way to apply function to each row of a numpy array
假设我有一些最近的邻居分类器。 对于新观察,它计算新观察与“已知”数据集中所有观察之间的距离。 它返回观测值的类标签,该类标签到新观测值的距离最小。
import numpy as np
known_obs = np.random.randint(0, 10, 40).reshape(8, 5)
new_obs = np.random.randint(0, 10, 80).reshape(16, 5)
labels = np.random.randint(0, 2, 8).reshape(8, )
def my_dist(x1, known_obs, axis=0):
return (np.square(np.linalg.norm(x1 - known_obs, axis=axis)))
def nn_classifier(n, known_obs, labels, axis=1, distance=my_dist):
return labels[np.argmin(distance(n, known_obs, axis=axis))]
def classify_batch(new_obs, known_obs, labels, classifier=nn_classifier, distance=my_dist):
return [classifier(n, known_obs, labels, distance=distance) for n in new_obs]
print(classify_batch(new_obs, known_obs, labels, nn_classifier, my_dist))
出于性能原因,我想避免classify_batch函数中的for循环。 有没有一种方法可以使用numpy操作将nn_classifier函数应用于new_obs的每一行? 我已经尝试过apply_along_axis,但是经常提到它很方便,但并不快。
避免循环的关键是在“距离”(16,8)数组上表达动作。 labels[]
和argmin
步骤只会使问题argmin
。
如果我将labels = np.arange(8)
设置labels = np.arange(8)
,则此
arr = np.array([my_dist(n, known_obs, axis=1) for n in new_obs])
print(arr)
print(np.argmin(arr, axis=1))
产生相同的东西。 它仍然具有列表理解功能,但我们更接近“源”。
[[ 32. 115. 22. 116. 162. 86. 161. 117.]
[ 106. 31. 142. 164. 92. 106. 45. 103.]
[ 44. 135. 94. 18. 94. 50. 87. 135.]
[ 11. 92. 57. 67. 79. 43. 118. 106.]
[ 40. 67. 126. 98. 50. 74. 75. 175.]
[ 78. 61. 120. 148. 102. 128. 67. 191.]
[ 51. 48. 57. 133. 125. 35. 110. 14.]
[ 47. 28. 93. 91. 63. 49. 32. 88.]
[ 61. 86. 23. 141. 159. 85. 146. 22.]
[ 131. 70. 155. 149. 129. 127. 44. 138.]
[ 97. 138. 87. 117. 223. 77. 130. 122.]
[ 151. 78. 211. 161. 131. 115. 46. 164.]
[ 13. 50. 31. 69. 59. 43. 80. 40.]
[ 131. 108. 157. 161. 207. 85. 102. 146.]
[ 39. 106. 67. 23. 61. 67. 70. 88.]
[ 54. 51. 74. 68. 42. 86. 35. 65.]]
[2 1 3 0 0 1 7 1 7 6 5 6 0 5 3 6]
用
print((new_obs[:,None,:] - known_obs[None,:,:]).shape)
我得到了(16,8,5)数组。 那么我可以在最后一个轴上应用linalg.norm
吗?
这似乎可以解决问题
np.square(np.linalg.norm(diff, axis=-1))
所以在一起:
diff = (new_obs[:,None,:] - known_obs[None,:,:])
dist = np.square(np.linalg.norm(diff, axis=-1))
idx = np.argmin(dist, axis=1)
print(idx)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.