[英]Numba JIT slower than pure python with parameterized function
我剛剛寫了一個比較 Numba 和 Julia 的 簡單基准測試,以及一些討論。
我想知道我的 Numba 代碼是否可以以某種方式修復,或者 Numba 是否確實不支持我正在嘗試做的事情。
我們的想法是使用 JIT 編譯的正交規則來評估這個 function。
g(p) = integrate exp(p*x) with respect to x
這是簡單的正交 function:
@nb.njit
def quad_trap(f,a,b,N):
h = (b-a)/N
integral = h * ( f(a) + f(b) ) / 2
for k in range(N):
xk = (b-a) * k/N + a
integral = integral + h*f(xk)
return integral
我可以將 JIT 編譯的 function 傳遞給這個 function,就像這個:
@nb.njit(nb.float64(nb.float64))
def func(x):
return math.exp(x) - 10
這比純 Python 快 10-20 倍左右,相當不錯。
現在,我想做的是傳遞 x 的 function 並由 p 參數化,類似於:
def g(p):
@nb.njit(nb.float64(nb.float64))
def integrand(x):
return math.exp(p*x) - 10
return quad_trap(integrand, -1, 1, 10000)
這樣做似乎會破壞 Numba,即使與純 Python 相比,它也會變得異常緩慢。
我做錯了什么,還是 Numba 確實不支持此功能? (我確實檢查了文檔,但我不明白問題出在哪里)。 謝謝!
TL;DR: Numba 似乎還不支持此功能。
這比純 Python 快 10-20 倍左右,非常好。
Numba function quad_trap
將在您第一次調用時編譯。 如果參數的類型發生變化,那么 Numba 將再次重新編譯 function。 編譯時間通常遠不能忽略(幾毫秒到幾秒)。 為了避免這種情況,解決方案通常是指定參數的類型。 但是,AFAIK,由於 function,這在此處是不可能的(至少沒有記錄在案)。 That being said, because you certainly benchmark the quad_trap
function with the same function, Numba should not recompile the function because the type of the provided arguments does not change.
這樣做似乎會破壞 Numba,即使與純 Python 相比,它也會變得非常慢。
在 Numba 的最新版本中,它可以在沒有警告的情況下工作,但他的原因是 function integrand
被一遍又一遍地重新編譯,因為 Numba 不知道其代碼是否更改(或在此函數中遞歸調用的一個函數/運算符)。 在舊版本中,Numba 可能會抱怨 function integrand
讀取的參數p
是從包含 function 的父級讀取的。 這稱為閉包。
編譯器通常不太支持閉包,因為處理它們要困難得多(它們需要從父函數的堆棧中讀取變量)。 一個經常出現的普遍問題是閉包可以逃脫其父 function 的 scope 並在外部調用導致未定義的行為(因為閉包將嘗試讀取已完成函數的失效堆棧)。
一個技巧是將@nb.njit
裝飾器從integrand
移動到g
但 Numba 拒絕編譯g
因為它不支持可能逃脫其父 function 的 scope 的閉包(由於前面描述的問題)。 Note that the closure does not escape the function where it is defined in your case but Numba cannot prove that (since the quad_trap
function is already compiled) and it also unfortunately fails to do that when the function quad_trap
is inlined (while it could theoretically prove這是安全的)。 事實上, 文件指出:
Numba 現在支持內部函數,只要它們是非遞歸的並且僅在本地調用,但不作為參數傳遞或作為結果返回。 還支持在內部 function 中使用閉包變量(在外部范圍中定義的變量)。
我認為@generated_jit
裝飾器可能有助於解決此類問題,但我沒有成功使其適用於您的特定情況。 它至少應該有助於在定義時(如integrand
)而不是在第一次調用期間編譯g
。
一種解決方案是不使用閉包:
@nb.njit
def quad_trap_p(f,a,b,N,p):
h = (b-a)/N
integral = h * ( f(a,p) + f(b,p) ) / 2
for k in range(N):
xk = (b-a) * k/N + a
integral = integral + h*f(xk,p)
return integral
@nb.njit(nb.float64(nb.float64, nb.float64))
def integrand(x, p):
return math.exp(p*x) - 10
def g(p):
return quad_trap_p(integrand, -1, 1, 10000, p)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.