[英]Numba jitted len() is slower than pure Python len()
我正在學習 numba 並遇到了這種我不理解的“奇怪”行為。 我嘗試使用以下代碼(在 iPython 中,用於計時):
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
def py_len(seq):
return len(seq)
##
t = np.random.rand(1000)
%timeit nb_len(t)
%timeit py_len(t)
結果如下(實際上是numba編譯導致的第二次運行):
258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
純 python 版本是 numba 版本的兩倍。 我也嘗試過簽名@nb.njit( nb.int32(nb.float64[:]) )
但結果還是一樣。
我在某處弄錯了嗎?
謝謝你。
增加時間的不是 len() 部分。 使用輸入參數調用 jit 函數會增加開銷,這就是您看到的時間差。
import numba as nb
def py_pass(i):
return i
@nb.njit()
def nb_pass(i):
return i
%timeit py_pass(1)
%timeit nb_pass(1)
102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
有趣的是,如果你不需要向 jit 函數傳遞任何東西,它會更快:
def py_pass():
return 1
@nb.njit()
def nb_pass():
return 1
%timeit py_pass()
%timeit nb_pass()
96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
正如另一個答案所述,在這種情況下不是因為len
函數,而是因為對 numba 函數的調用實際上比對普通 Python 函數的調用慢。
jit
-ted 函數與眾不同?要理解為什么調用 numba jitted 函數較慢,必須了解 numba jitted 函數不再是函數。 這是一個調度程序對象:
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
print(nb_len) # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)
這個CPUDispatcher
實例代表(可能)多個基於裝飾函數生成的編譯函數。
這意味着當您調用CPUDispatcher
實例時,有多個步驟:
與未修飾的函數相比,所有這些步驟都會增加開銷。 特別是如果沒有合適的編譯函數並且調度程序需要編譯函數 - 或者 - 輸入類型需要轉換(只發生在 Python 類型,如:列表、集合、字典)調用CPUDispatcher
會慢很多 - 這些類型正在在編寫 numba 0.46 時已棄用,部分原因是,請參閱“2.11.2. 不推薦使用 List 和 Set 類型的反射” 。
在您的情況下,由於編譯,對 jitted 函數的第一次調用將明顯變慢。
任何后續調用只會稍微慢一點,因為 numba 必須獲取參數類型,檢查是否已經存在已編譯的函數,然后調用該已編譯函數。 有趣的是,額外的時間取決於參數的數量和該函數已經編譯的“重載”的數量。 通常這個額外的時間是微不足道的,因為該函數所做的不僅僅是調用len
。
盡管該函數非常簡單,但第一次調用時的編譯需要大量時間:
import numpy as np
import numba as nb
def first_call(seq):
@nb.njit
def nb_len(seq):
return len(seq)
return nb_len(seq)
@nb.njit
def _nb_len(seq):
return len(seq)
def subsequent_calls(seq):
return _nb_len(seq)
t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))
%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
此外,如果 numba 需要轉換參數,它會慢很多。 這僅發生在 numba 無法直接處理的 Python 類型中,例如列表:
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
arr = np.random.rand(10_000)
lst = arr.tolist()
nb_len(arr)
nb_len(lst)
%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
len()
在nb_len
和len()
在py_len
可以有完全不同的運行時間。 然而,在這種情況下,運行時間幾乎相同。 但是,意識到這一點通常是件好事。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.