簡體   English   中英

Python如何提高numpy數組的性能?

[英]Python How to improve numpy array performance?

我有一個全局numpy.array 數據 ,它是一個200 * 200 * 3 3d數組,在3d空間中包含40000個點。

我的目標是計算每個點到單位立方體四個角的距離((0,0,0),(1,0,0),(0,1,0),(0,0,1) ),因此我可以確定哪個角離它最近。

def dist(*point):
    return np.linalg.norm(data - np.array(rgb), axis=2)

buffer = np.stack([dist(0, 0, 0), dist(1, 0, 0), dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0)

我編寫了這段代碼並對其進行了測試,每次運行大約花費10毫秒。 我的問題是如何改善這段代碼的性能,最好在不到1ms的時間內運行。

您可以使用Scipy cdist

# unit cube coordinates as array
uc = np.array([[0, 0, 0],[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# buffer output
buf = cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1)

運行時測試

# Original approach
def org_app():
    return np.stack([dist(0, 0, 0), dist(1, 0, 0), \
       dist(0, 1, 0), dist(0, 0, 1)]).argmin(axis=0)

時間-

In [170]: data = np.random.rand(200,200,3)

In [171]: %timeit org_app()
100 loops, best of 3: 4.24 ms per loop

In [172]: %timeit cdist(data.reshape(-1,3), uc).argmin(1).reshape(data.shape[0],-1)
1000 loops, best of 3: 1.25 ms per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM