簡體   English   中英

計算python中每個點之間的距離的最快方法

[英]Fastest way to compute distance beetween each points in python

在我的項目中,我需要計算存儲在數組中的每個點之間的歐氏距離。 入口數組是一個二維的numpy數組,具有3列,分別是坐標(x,y,z),每行定義一個新點。

我通常在測試用例中使用5000-6000分。

我的第一個算法使用Cython和第二個Numpy。 我發現我的numpy算法比cython更快。

編輯:6000點:

numpy 1.76秒/ cython 4.36秒

這是我的cython代碼:

cimport cython
from libc.math cimport sqrt
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void calcul1(double[::1] M,double[::1] R):

  cdef int i=0
  cdef int max = M.shape[0]
  cdef int x,y
  cdef int start = 1

  for x in range(0,max,3):
     for y in range(start,max,3):

        R[i]= sqrt((M[y] - M[x])**2 + (M[y+1] - M[x+1])**2 + (M[y+2] - M[x+2])**2)
        i+=1  

     start += 1

M是初始條目數組的內存視圖,但在調用函數calcul1()之前由numpy flatten() calcul1() ,R是用於存儲所有結果的一維輸出數組的內存視圖。

這是我的Numpy代碼:

def calcul2(M):

     return np.sqrt(((M[:,:,np.newaxis] - M[:,np.newaxis,:])**2).sum(axis=0))

這里M是初始條目數組,但是在函數調用之前以numpy進行transpose() ,以將坐標(x,y,z)作為行並將點作為列。

而且,此numpy函數非常方便,因為它返回的數組組織良好。 這是一個n個點數為n的by n數組,每個點都有一行和一列。 因此,例如,距離AB存儲在行A和列B的交點索引處。

這是我如何稱呼它們(cython函數):

cpdef test():

  cdef double[::1] Mf 
  cdef double[::1] out = np.empty(17998000,dtype=np.float64) # (6000² - 6000) / 2

  M = np.arange(6000*3,dtype=np.float64).reshape(6000,3) # Example array with 6000 points
  Mf = M.flatten() #because my cython algorithm need a 1D array
  Mt = M.transpose() # because my numpy algorithm need coordinates as rows

  calcul2(Mt)

  calcul1(Mf,out)

我在這里做錯什么了嗎? 對於我的項目來說,兩者都不夠快。

1:有沒有辦法改善我的cython代碼以擊敗numpy的速度?

2:是否可以改善numpy代碼以使其計算更快?

3:或其他解決方案,但必須是python / cython(如並行計算)?

謝謝。

不知道在哪里獲取時間,但是可以使用scipy.spatial.distance

M = np.arange(6000*3, dtype=np.float64).reshape(6000,3)
np_result = calcul2(M)
sp_result = sd.cdist(M.T, M.T) #Scipy usage
np.allclose(np_result, sp_result)
>>> True

時間:

%timeit calcul2(M)
1000 loops, best of 3: 313 µs per loop

%timeit sd.cdist(M.T, M.T)
10000 loops, best of 3: 86.4 µs per loop

重要的是,它對於了解輸出是對稱的也很有用:

np.allclose(sp_result, sp_result.T)
>>> True

一種替代方法是僅計算此數組的上三角:

%timeit sd.pdist(M.T)
10000 loops, best of 3: 39.1 µs per loop

編輯:不確定要壓縮哪個索引,看起來您可能同時做這兩種方式? 壓縮其他索引以進行比較:

%timeit sd.pdist(M)
10 loops, best of 3: 135 ms per loop

仍比您當前的NumPy實現快約10倍。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM