[英]Faster median in very large numpy arrays
我有一個非常大的numpy數組,其維度為(4000,6000,15)。
我現在想要每個堆棧的中值,即沿第三維。 當前代碼可以工作,但是奇怪的是速度很慢,單個堆棧的中位數[0,0,:](15個值)至少需要半秒左右的時間才能完成。
height = 4000
width = 6000
N = 15
poolmedian = np.zeros((height,width,3))
RGBmedian = np.zeros((height,width,N), dtype=float)
for n in range(0,height):
for m in range(0,width):
poolmedian[n,m,0] = np.median(RGBmedian[n,m,:])
您將要盡可能向量化中值計算。 每次調用numpy
函數時,都會在C和Python層之間來回移動。 在C層中盡可能多地執行以下操作:
import numpy as np
height = 40
width = 60
N = 15
np.random.seed(1)
poolmedian = np.zeros((height,width,3))
RGBmedian = np.random.random((height,width,N))
def original():
for n in range(0,height):
for m in range(0,width):
poolmedian[n,m,0] = np.median(RGBmedian[n,m,:])
return poolmedian
def vectorized():
# Note: np.median is only called ONCE, not n*m times.
poolmedian[:, :, 0] = np.median(RGBmedian, axis=-1)
return poolmedian
orig = original()
vec = vectorized()
np.testing.assert_array_equal(orig, vec)
您可以看到,自斷言通過以來,值是相同的(盡管不清楚為什么在poolmedian
需要3個dims)。 我將上面的代碼放在一個名為test.py的文件中,並使用IPython來方便使用%timeit
。 我也略微減小了大小,以使其運行更快,但是您應該在大數據上獲得類似的節省。 向量化版本的速度提高了約100倍:
In [1]: from test import original, vectorized
In [2]: %timeit original()
69.1 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [3]: %timeit vectorized()
618 µs ± 4.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
通常,您要使用numpy
的廣播規則並盡可能少地調用一個函數。 如果您正在尋找高性能的numpy
代碼,則在循環中調用函數幾乎總是 numpy
。
附錄:
我已經在test.py中添加了以下函數,因為還有另一個答案,所以我想說明一下,調用完全矢量化的版本(即無循環)會更快,並且還可以修改為使用4000到6000的dims :
import numpy as np
height = 4000
width = 6000
N = 15
...
def fordy():
for n in range(0,height):
for m in range(0,width):
array = RGBmedian[n,m,:]
array.sort()
poolmedian[n, m, 0] = (array[6] + array[7])/2
return poolmedian
如果將所有這些都加載到IPython中,則會得到:
In [1]: from test import original, fordy, vectorized
In [2]: %timeit original()
6.87 s ± 72.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [3]: %timeit fordy()
262 ms ± 737 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [4]: %timeit vectorized()
18.4 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
HTH。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.