簡體   English   中英

Numba JIT 比帶有參數化 function 的純 python 慢

[英]Numba JIT slower than pure python with parameterized function

我剛剛寫了一個比較 Numba 和 Julia 的 簡單基准測試,以及一些討論。

我想知道我的 Numba 代碼是否可以以某種方式修復,或者 Numba 是否確實不支持我正在嘗試做的事情。

我們的想法是使用 JIT 編譯的正交規則來評估這個 function。

g(p) = integrate exp(p*x) with respect to x

這是簡單的正交 function:

@nb.njit   
def quad_trap(f,a,b,N):
    h = (b-a)/N
    integral = h * ( f(a) + f(b) ) / 2
    for k in range(N):
        xk = (b-a) * k/N + a
        integral = integral + h*f(xk)
    return integral

我可以將 JIT 編譯的 function 傳遞給這個 function,就像這個:

@nb.njit(nb.float64(nb.float64))
def func(x):
    return math.exp(x) - 10

這比純 Python 快 10-20 倍左右,相當不錯。

現在,我想做的是傳遞 x 的 function 並由 p 參數化,類似於:

def g(p): 
    @nb.njit(nb.float64(nb.float64))
    def integrand(x):
        return math.exp(p*x) - 10
    return quad_trap(integrand, -1, 1, 10000) 

這樣做似乎會破壞 Numba,即使與純 Python 相比,它也會變得異常緩慢。

我做錯了什么,還是 Numba 確實不支持此功能? (我確實檢查了文檔,但我不明白問題出在哪里)。 謝謝!

TL;DR: Numba 似乎還不支持此功能。

這比純 Python 快 10-20 倍左右,非常好。

Numba function quad_trap將在您第一次調用時編譯。 如果參數的類型發生變化,那么 Numba 將再次重新編譯 function。 編譯時間通常遠不能忽略(幾毫秒到幾秒)。 為了避免這種情況,解決方案通常是指定參數的類型。 但是,AFAIK,由於 function,這在此處是不可能的(至少沒有記錄在案)。 That being said, because you certainly benchmark the quad_trap function with the same function, Numba should not recompile the function because the type of the provided arguments does not change.

這樣做似乎會破壞 Numba,即使與純 Python 相比,它也會變得非常慢。

在 Numba 的最新版本中,它可以在沒有警告的情況下工作,但他的原因是 function integrand被一遍又一遍地重新編譯,因為 Numba 不知道其代碼是否更改(或在此函數中遞歸調用的一個函數/運算符)。 在舊版本中,Numba 可能會抱怨 function integrand讀取的參數p是從包含 function 的父級讀取的。 這稱為閉包。

編譯器通常不太支持閉包,因為處理它們要困難得多(它們需要從父函數的堆棧中讀取變量)。 一個經常出現的普遍問題是閉包可以逃脫其父 function 的 scope 並在外部調用導致未定義的行為(因為閉包將嘗試讀取已完成函數的失效堆棧)。

一個技巧是將@nb.njit裝飾器從integrand移動到g但 Numba 拒絕編譯g因為它不支持可能逃脫其父 function 的 scope 的閉包(由於前面描述的問題)。 Note that the closure does not escape the function where it is defined in your case but Numba cannot prove that (since the quad_trap function is already compiled) and it also unfortunately fails to do that when the function quad_trap is inlined (while it could theoretically prove這是安全的)。 事實上, 文件指出:

Numba 現在支持內部函數,只要它們是非遞歸的並且僅在本地調用,但不作為參數傳遞或作為結果返回。 還支持在內部 function 中使用閉包變量(在外部范圍中定義的變量)。

我認為@generated_jit裝飾器可能有助於解決此類問題,但我沒有成功使其適用於您的特定情況。 它至少應該有助於在定義時(如integrand )而不是在第一次調用期間編譯g

一種解決方案是不使用閉包:

@nb.njit
def quad_trap_p(f,a,b,N,p):
    h = (b-a)/N
    integral = h * ( f(a,p) + f(b,p) ) / 2
    for k in range(N):
        xk = (b-a) * k/N + a
        integral = integral + h*f(xk,p)
    return integral

@nb.njit(nb.float64(nb.float64, nb.float64))
def integrand(x, p):
    return math.exp(p*x) - 10

def g(p):
    return quad_trap_p(integrand, -1, 1, 10000, p)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM