簡體   English   中英

Numba jitted len() 比純 Python len() 慢

[英]Numba jitted len() is slower than pure Python len()

我正在學習 numba 並遇到了這種我不理解的“奇怪”行為。 我嘗試使用以下代碼(在 iPython 中,用於計時):

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

def py_len(seq):
    return len(seq)

##
t = np.random.rand(1000)

%timeit nb_len(t)
%timeit py_len(t)

結果如下(實際上是numba編譯導致的第二次運行):

258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

純 python 版本是 numba 版本的兩倍。 我也嘗試過簽名@nb.njit( nb.int32(nb.float64[:]) )但結果還是一樣。

我在某處弄錯了嗎?

謝謝你。

增加時間的不是 len() 部分。 使用輸入參數調用 jit 函數會增加開銷,這就是您看到的時間差。

import numba as nb

def py_pass(i):
    return i

@nb.njit()
def nb_pass(i):
    return i

%timeit py_pass(1)
%timeit nb_pass(1)

帶輸入參數的結果

102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

有趣的是,如果你不需要向 jit 函數傳遞任何東西,它會更快:

def py_pass():
    return 1

@nb.njit()
def nb_pass():
    return 1

%timeit py_pass()
%timeit nb_pass()

沒有輸入參數的結果

96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

正如另一個答案所述,在這種情況下不是因為len函數,而是因為對 numba 函數的調用實際上比對普通 Python 函數的調用慢。

是什么讓jit -ted 函數與眾不同?

要理解為什么調用 numba jitted 函數較慢,必須了解 numba jitted 函數不再是函數。 這是一個調度程序對象:

import numba as nb
@nb.njit
def nb_len(seq):
    return len(seq)
print(nb_len)  # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)

這個CPUDispatcher實例代表(可能)多個基於裝飾函數生成的編譯函數。

這意味着當您調用CPUDispatcher實例時,有多個步驟:

  • 獲取參數的類型。
  • 如果這些類型的參數沒有合適的編譯函數,則使用參數類型編譯裝飾函數。
  • 有時:將參數轉換為相應的 numba 類型。
  • 調用編譯后的函數。

與未修飾的函數相比,所有這些步驟都會增加開銷。 特別是如果沒有合適的編譯函數並且調度程序需要編譯函數 - 或者 - 輸入類型需要轉換(只發生在 Python 類型,如:列表、集合、字典)調用CPUDispatcher會慢很多 - 這些類型正在在編寫 numba 0.46 時已棄用,部分原因是,請參閱“2.11.2. 不推薦使用 List 和 Set 類型的反射”

在你的情況下

在您的情況下,由於編譯,對 jitted 函數的第一次調用將明顯變慢。

任何后續調用只會稍微慢一點,因為 numba 必須獲取參數類型,檢查是否已經存在已編譯的函數,然后調用該已編譯函數。 有趣的是,額外的時間取決於參數的數量和該函數已經編譯的“重載”的數量。 通常這個額外的時間是微不足道的,因為該函數所做的不僅僅是調用len

編譯時間

盡管該函數非常簡單,但第一次調用時的編譯需要大量時間:

import numpy as np
import numba as nb

def first_call(seq):
    @nb.njit
    def nb_len(seq):
        return len(seq)
    return nb_len(seq)

@nb.njit
def _nb_len(seq):
    return len(seq)

def subsequent_calls(seq):
    return _nb_len(seq)

t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))

%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

轉換時間

此外,如果 numba 需要轉換參數,它會慢很多。 這僅發生在 numba 無法直接處理的 Python 類型中,例如列表:

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

arr = np.random.rand(10_000)
lst = arr.tolist()

nb_len(arr)
nb_len(lst)

%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

概括

  • 與普通的 Python 函數相比,Numba 函數有一些額外的開銷。 因此,請確保您做了 numba 擅長優化的“足夠”的事情,否則普通的 Python 函數將更快、更靈活且更易於調試。
  • numba 函數中的函數調用確實與 numba 函數之外的函數調用不同。 所以len()nb_lenlen()py_len可以有完全不同的運行時間。 然而,在這種情況下,運行時間幾乎相同。 但是,意識到這一點通常是件好事。
  • 根據參數類型,numba 函數可能(在幕后)非常慢,尤其是在將 Python 類型作為參數或返回類型處理時!

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM