[英]Cython - efficiently filtering a typed memoryview
此Cython函數返回numpy數組中位於一定范圍內的元素中的隨機元素:
cdef int search(np.ndarray[int] pool):
cdef np.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
這樣很好。 但是,此功能對於我的代碼的性能非常關鍵。 類型化的內存視圖顯然比numpy數組快得多,但是不能以與上述相同的方式對其進行過濾。
我該如何使用鍵入的memoryviews編寫一個與上述功能相同的函數? 還是有另一種方法來改善功能的性能?
好的,讓我們開始使代碼更通用,稍后再討論性能。
我通常不使用:
import numpy as np
cimport numpy as np
我個人喜歡為cimport
包使用不同的名稱,因為它有助於使C端和NumPy-Python端保持分開。 所以對於這個答案,我將使用
import numpy as np
cimport numpy as cnp
另外,我將對該函數進行lower_limit
和upper_limit
參數。 也許在您的情況下是靜態(或全局)定義的,但是這使示例更加獨立。 因此,起點是您的代碼的稍作修改的版本:
cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
Cython的一個非常不錯的功能是融合類型 ,因此您可以輕松地將此功能推廣為不同類型。 您的方法僅適用於32位整數數組(至少如果int
在您的計算機上為32位)。 支持更多的數組類型非常容易:
ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t
cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
當然,您可以根據需要添加更多類型。 優點是新版本可以在舊版本失敗的地方工作:
>>> search_1(np.arange(100, dtype=np.float_), 10, 20)
ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
>>> search_2(np.arange(100, dtype=np.float_), 10, 20)
19.0
現在,更籠統地說,讓我們看一下您的函數實際執行的操作:
為什么要創建這么多數組? 我的意思是,你可以簡單地計算范圍內有多少個元素都在,取0之間的范圍內的元素數量的隨機整數,然后采取任何元素將是這個結果數組中的索引。
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
# Count the number of elements that are within the limits
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1
# Take a random index
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
# Go through the array again and take the element at the random index that
# is within the bounds
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1
它不會快很多,但是會節省很多內存。 而且因為沒有中間數組,所以根本不需要內存視圖-但是如果願意,可以將參數列表中的cnp.ndarray[int_or_float] arr
替換為int_or_float[:]
甚至int_or_float[::1] arr
並在memoryview上運行(它可能不會更快,但也不會很慢)。
與Cython相比,我通常更喜歡numba(至少在我使用它的情況下),因此讓我們將其與該代碼的numba版本進行比較:
import numba as nb
import numpy as np
@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1
random_index = np.random.randint(0, num_valids)
valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1
還有一個numexpr
變體:
import numexpr
np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
好的,讓我們做一個基准測試:
from simple_benchmark import benchmark, MultiArgument
arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
b = benchmark(funcs, arguments, argument_name='array size')
因此,通過不使用中間數組,您的速度大約可以提高5倍,而如果使用numba,則可能會增加5倍(好像我在那里缺少一些可能的Cython優化,numba通常快約2倍,或者是Cython一樣快。 )。 因此,使用numba解決方案可以使它快20倍左右。
numexpr
在這里不是真正可比的,主要是因為您不能在此處使用布爾數組索引。
差異將取決於數組的內容和限制。 您還必須衡量應用程序的性能。
np.random.choice
:如果下限和上限通常不改變,最快的解決方案是過濾一次數組,然后對它多次調用np.random.choice
。 那可能快幾個數量級 。
lower_limit = ...
upper_limit = ...
filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]
def search_cached():
return np.random.choice(filtered_array)
%timeit search_cached()
2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
因此幾乎快了1000倍,完全不需要Cython或numba。 但這是一種特殊情況,可能對您沒有用。
如果您想自己進行基准測試,請在此處(基於Jupyter Notebook / Lab,因此為%
符號):
%load_ext cython
%%cython
cimport numpy as cnp
import numpy as np
cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
cdef cnp.ndarray[int] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
ctypedef fused int_or_float:
cnp.int32_t
cnp.int64_t
cnp.float32_t
cnp.float64_t
cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
cdef cnp.ndarray[int_or_float] limited
limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
return np.random.choice(limited)
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
cdef int_or_float element
cdef Py_ssize_t num_valid = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
num_valid += 1
cdef Py_ssize_t random_index = np.random.randint(0, num_valid)
cdef Py_ssize_t clamped_index = 0
for index in range(arr.shape[0]):
element = arr[index]
if lower_bound <= element <= upper_bound:
if clamped_index == random_index:
return element
clamped_index += 1
import numexpr
import numba as nb
import numpy as np
def search_numexpr(arr, l, u):
return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
@nb.njit
def search_numba(arr, lower, upper):
num_valids = 0
for item in arr:
if item >= lower and item <= upper:
num_valids += 1
random_index = np.random.randint(0, num_valids)
valid_index = 0
for item in arr:
if item >= lower and item <= upper:
if valid_index == random_index:
return item
valid_index += 1
from simple_benchmark import benchmark, MultiArgument
arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]
b = benchmark(funcs, arguments, argument_name='array size')
%matplotlib widget
import matplotlib.pyplot as plt
plt.style.use('ggplot')
b.plot()
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.