[英]numpy faster than numba and cython , how to improve numba code
[英]Can my numba code be faster than numpy
我是 Numba 的新手,正在嘗試加快一些對 numpy 來說太笨重的計算。 我在下面給出的示例比較了一個包含我的計算子集的函數,該函數使用函數的矢量化/numpy 和 numba 版本,后者也通過注釋掉 @autojit 裝飾器作為純 python 進行了測試。
我發現 numba 和 numpy 版本相對於純 python 提供了相似的加速,這兩者都是大約 10 倍的速度提升。 numpy 版本實際上比我的 numba 函數稍快,但由於這種計算的 4D 性質,當 numpy 函數中的數組大小比這個玩具示例大得多時,我很快就會耗盡內存。
這種加速很好,但是當從純 python 移動到 numba 時,我經常在網絡上看到超過 100 倍的加速。
我想知道在 nopython 模式下移動到 numba 時是否有普遍的預期速度增加。 我還想知道我的 numba 化函數中是否有任何組件會限制進一步的速度提升。
import numpy as np
from timeit import default_timer as timer
from numba import autojit
import math
def vecRadCalcs(slope, skyz, solz, skya, sola):
nloc = len(slope)
ntime = len(solz)
[lenz, lena] = skyz.shape
asolz = np.tile(np.reshape(solz,[ntime,1,1,1]),[1,nloc,lenz,lena])
asola = np.tile(np.reshape(sola,[ntime,1,1,1]),[1,nloc,lenz,lena])
askyz = np.tile(np.reshape(skyz,[1,1,lenz,lena]),[ntime,nloc,1,1])
askya = np.tile(np.reshape(skya,[1,1,lenz,lena]),[ntime,nloc,1,1])
phi1 = np.cos(asolz)*np.cos(askyz)
phi2 = np.sin(asolz)*np.sin(askyz)*np.cos(askya- asola)
phi12 = phi1 + phi2
phi12[phi12> 1.0] = 1.0
phi = np.arccos(phi12)
return(phi)
@autojit
def RadCalcs(slope, skyz, solz, skya, sola, phi):
nloc = len(slope)
ntime = len(solz)
pop = 0.0
[lenz, lena] = skyz.shape
for iiT in range(ntime):
asolz = solz[iiT]
asola = sola[iiT]
for iL in range(nloc):
for iz in range(lenz):
for ia in range(lena):
askyz = skyz[iz,ia]
askya = skya[iz,ia]
phi1 = math.cos(asolz)*math.cos(askyz)
phi2 = math.sin(asolz)*math.sin(askyz)*math.cos(askya- asola)
phi12 = phi1 + phi2
if phi12 > 1.0:
phi12 = 1.0
phi[iz,ia] = math.acos(phi12)
pop = pop + 1
return(pop)
zenith_cells = 90
azim_cells = 360
nloc = 10 # nominallly ~ 700
ntim = 10 # nominallly ~ 200000
slope = np.random.rand(nloc) * 10.0
solz = np.random.rand(ntim) *np.pi/2.0
sola = np.random.rand(ntim) * 1.0*np.pi
base = np.ones([zenith_cells,azim_cells])
skya = np.deg2rad(np.cumsum(base,axis=1))
skyz = np.deg2rad(np.cumsum(base,axis=0)*90/zenith_cells)
phi = np.zeros(skyz.shape)
start = timer()
outcalc = RadCalcs(slope, skyz, solz, skya, sola, phi)
stop = timer()
outcalc2 = vecRadCalcs(slope, skyz, solz, skya, sola)
stopvec = timer()
print(outcalc)
print(stop-start)
print(stopvec-stop)
在我運行 numba 0.31.0 的機器上,Numba 版本比矢量化解決方案快 2 倍。 在對 numba 函數進行計時時,您需要多次運行該函數,因為第一次看到的是 jitting 代碼的時間 + 運行時間。 后續運行將不包括 jitting 函數時間的開銷,因為 Numba 將 jitted 代碼緩存在內存中。
另外,請注意您的函數不是計算相同的東西——您要小心,使用類似np.allclose
之類的結果比較相同的東西。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.