I have the following Python function:
def npnearest(u: np.ndarray, X: np.ndarray, Y: np.ndarray, distance: 'callable' = npdistance):
    '''
    Finds x1 so that x1 is in X and u and x1 have a minimal distance (according to the
    provided distance function) compared to all other data points in X. Returns the label of x1.
    Args:
        u (np.ndarray): The vector (ndim=1) we want to classify
        X (np.ndarray): A matrix (ndim=2) with training data points (vectors)
        Y (np.ndarray): A vector containing the label of each data point in X
        distance (callable): A function that receives two inputs and defines the distance function used
    Returns:
        int: The label of the data point which is closest to `u`
    '''
    xbest = None
    ybest = None
    dbest = float('inf')
    for x, y in zip(X, Y):
        d = distance(u, x)
        if d < dbest:
            ybest = y
            xbest = x
            dbest = d
    return ybest
where npdistance simply computes the squared Euclidean distance between two points:
def npdistance(x1, x2):
    return np.sum((x1 - x2) ** 2)
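Because npdistance returns the squared distance rather than the Euclidean distance itself, it may help to confirm that this does not change which point is nearest: the square root is monotonically increasing, so both measures rank the points identically. A small check, using made-up example points:

```python
import numpy as np

def npdistance(x1, x2):
    # Squared Euclidean distance, as defined in the question
    return np.sum((x1 - x2) ** 2)

u = np.array([1.0, 5.5])
X = np.array([[1, 2], [1, 5], [0, 0], [7, 7]])

squared = np.sum((X - u) ** 2, axis=1)   # [12.25, 0.25, 31.25, 38.25]
euclid = np.linalg.norm(X - u, axis=1)   # square roots of the values above

# sqrt preserves ordering, so both pick the same nearest point
print(int(np.argmin(squared)), int(np.argmin(euclid)))  # 1 1
```

Skipping the square root simply saves work; it does not affect the result of a nearest-neighbor search.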
I want to optimize npnearest by performing the nearest-neighbor search directly in NumPy. This means that the function cannot use for/while loops.
Thanks
Since you don't need to use that exact function, you can simply make the sum run over a particular axis (axis=1, i.e. over the coordinates of each point). This returns a new array with one distance per row of X, and you can call argmin on it to get the index of the minimum value. Use that index to look up your label:
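import numpy as np
def npdistance_idx(x1, x2):
    # Sum the squared differences along axis 1 (one distance per row),
    # then return the index of the smallest distance
    return np.argmin(np.sum((x1 - x2) ** 2, axis=1))
Y = ["label 0", "label 1", "label 2", "label 3"]
u = np.array([[1, 5.5]])                        # shape (1, 2), broadcast against every row of X
X = np.array([[1, 2], [1, 5], [0, 0], [7, 7]])  # shape (4, 2)
idx = npdistance_idx(X, u)
print(Y[idx])  # label 1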
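If you want to keep the question's npnearest signature, including the distance parameter, one option is to require the distance callable to broadcast. Below is a minimal sketch assuming a hypothetical helper npdistance_rows that sums over the last axis. Note that the question's npdistance cannot be passed here unchanged: np.sum without an axis collapses all rows into a single number.

```python
import numpy as np
from typing import Callable

def npdistance_rows(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    # Hypothetical helper: squared Euclidean distance over the last axis.
    # For two 1-D vectors it returns the same scalar as npdistance;
    # for shapes (d,) and (n, d) it returns one distance per row, shape (n,).
    return np.sum((x1 - x2) ** 2, axis=-1)

def npnearest(u: np.ndarray, X: np.ndarray, Y,
              distance: Callable[[np.ndarray, np.ndarray], np.ndarray] = npdistance_rows):
    """Return the label of the row of X closest to u, without for/while loops."""
    # Assumes `distance` broadcasts u against all rows of X at once
    d = distance(u, X)
    # argmin returns the first index of the minimum, so ties resolve to the
    # earliest point, matching the strict `<` comparison in the loop version
    return Y[np.argmin(d)]

X = np.array([[1, 2], [1, 5], [0, 0], [7, 7]])
Y = ["label 0", "label 1", "label 2", "label 3"]
print(npnearest(np.array([1, 5.5]), X, Y))  # label 1
```

Any other metric can be swapped in the same way, as long as it reduces over the last axis instead of over the whole array.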
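NumPy supports vectorized operations and broadcasting.
This means you can pass in whole arrays, and each operation is applied element-wise across them in compiled code (which can use SIMD, single instruction, multiple data) rather than in a Python-level loop. Broadcasting additionally lets arrays of different but compatible shapes, such as a single point and a matrix of points, be combined without first replicating the smaller one.
You can then get the index of the array's minimum using .argmin().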
Hope this helps
In [9]: numbers = np.arange(10); numbers
Out[9]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
In [10]: numbers -= 5; numbers
Out[10]: array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4])
In [11]: numbers = np.power(numbers, 2); numbers
Out[11]: array([25, 16, 9, 4, 1, 0, 1, 4, 9, 16])
In [12]: numbers.argmin()
Out[12]: 5
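The same broadcasting idea extends to classifying many query points at once by adding one more axis. This is a sketch with a hypothetical helper name, npnearest_batch, and illustrative data:

```python
import numpy as np

def npnearest_batch(U: np.ndarray, X: np.ndarray, Y) -> np.ndarray:
    # Hypothetical helper: classify m query points at once.
    # U has shape (m, d), X has shape (n, d).
    # Inserting new axes gives shapes (m, 1, d) and (1, n, d),
    # which broadcast to an (m, n, d) array of coordinate differences.
    diff = U[:, np.newaxis, :] - X[np.newaxis, :, :]
    D = np.sum(diff ** 2, axis=2)               # (m, n): squared distance from each query to each point
    return np.asarray(Y)[np.argmin(D, axis=1)]  # (m,): label of the nearest point per query

X = np.array([[1, 2], [1, 5], [0, 0], [7, 7]])
Y = ["label 0", "label 1", "label 2", "label 3"]
U = np.array([[1, 5.5], [6, 6]])
print(npnearest_batch(U, X, Y))  # ['label 1' 'label 3']
```

The intermediate array has m * n * d elements, so for large inputs it may be necessary to process the queries in chunks.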