簡體   English   中英

Cython-有效過濾類型化的內存視圖

[英]Cython - efficiently filtering a typed memoryview

此Cython函數返回numpy數組中位於一定范圍內的元素中的隨機元素:

cdef int search(np.ndarray[int] pool):
  cdef np.ndarray[int] limited
  limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
  return np.random.choice(limited)

這樣很好。 但是,此功能對於我的代碼的性能非常關鍵。 類型化的內存視圖顯然比numpy數組快得多,但是不能以與上述相同的方式對其進行過濾。

我該如何使用鍵入的memoryviews編寫一個與上述功能相同的函數? 還是有另一種方法來改善功能的性能?

好的,讓我們開始使代碼更通用,稍后再討論性能。

我通常不使用:

import numpy as np
cimport numpy as np

我個人喜歡為cimport包使用不同的名稱,因為它有助於使C端和NumPy-Python端保持分開。 所以對於這個答案,我將使用

import numpy as np
cimport numpy as cnp

另外,我將對該函數進行lower_limitupper_limit參數。 也許在您的情況下是靜態(或全局)定義的,但是這使示例更加獨立。 因此,起點是您的代碼的稍作修改的版本:

cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
    cdef cnp.ndarray[int] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)

Cython的一個非常不錯的功能是融合類型 ,因此您可以輕松地將此功能推廣為不同類型。 您的方法僅適用於32位整數數組(至少如果int在您的計算機上為32位)。 支持更多的數組類型非常容易:

ctypedef fused int_or_float:
    cnp.int32_t
    cnp.int64_t
    cnp.float32_t
    cnp.float64_t

cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
    cdef cnp.ndarray[int_or_float] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)

當然,您可以根據需要添加更多類型。 優點是新版本可以在舊版本失敗的地方工作:

>>> search_1(np.arange(100, dtype=np.float_), 10, 20)
ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
>>> search_2(np.arange(100, dtype=np.float_), 10, 20)
19.0

現在,更籠統地說,讓我們看一下您的函數實際執行的操作:

  • 您創建一個布爾數組,其中元素高於下限
  • 您創建一個布爾數組,其中元素低於上限
  • 您可以按位和兩個布爾數組中的一個來創建布爾數組。
  • 您創建一個僅包含布爾掩碼為true的元素的新數組
  • 您只從最后一個數組中提取一個元素

為什么要創建這么多數組? 我的意思是,你可以簡單地計算范圍內有多少個元素都在,取0之間的范圍內的元素數量的隨機整數,然后采取任何元素將是這個結果數組中的索引。

cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
    cdef int_or_float element

    # Count the number of elements that are within the limits
    cdef Py_ssize_t num_valid = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            num_valid += 1

    # Take a random index
    cdef Py_ssize_t random_index = np.random.randint(0, num_valid)

    # Go through the array again and take the element at the random index that
    # is within the bounds
    cdef Py_ssize_t clamped_index = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            if clamped_index == random_index:
                return element
            clamped_index += 1

它不會快很多,但是會節省很多內存。 而且因為沒有中間數組,所以根本不需要內存視圖-但是如果願意,可以將參數列表中的cnp.ndarray[int_or_float] arr替換為int_or_float[:]甚至int_or_float[::1] arr並在memoryview上運行(它可能不會更快,但也不會很慢)。

與Cython相比,我通常更喜歡numba(至少在我使用它的情況下),因此讓我們將其與該代碼的numba版本進行比較:

import numba as nb
import numpy as np

@nb.njit
def search_numba(arr, lower, upper):
    num_valids = 0
    for item in arr:
        if item >= lower and item <= upper:
            num_valids += 1

    random_index = np.random.randint(0, num_valids)

    valid_index = 0
    for item in arr:
        if item >= lower and item <= upper:
            if valid_index == random_index:
                return item
            valid_index += 1

還有一個numexpr變體:

import numexpr

np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])

好的,讓我們做一個基准測試:

from simple_benchmark import benchmark, MultiArgument

arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]

b = benchmark(funcs, arguments, argument_name='array size')

在此處輸入圖片說明

因此,通過不使用中間數組,您的速度大約可以提高5倍,而如果使用numba,則可能會增加5倍(好像我在那里缺少一些可能的Cython優化,numba通常快約2倍,或者是Cython一樣快。 )。 因此,使用numba解決方案可以使它快20倍左右。

numexpr在這里不是真正可比的,主要是因為您不能在此處使用布爾數組索引。

差異將取決於數組的內容和限制。 您還必須衡量應用程序的性能。


np.random.choice :如果下限和上限通常不改變,最快的解決方案是過濾一次數組,然后對它多次調用np.random.choice 那可能快幾個數量級

lower_limit = ...
upper_limit = ...
filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]

def search_cached():
    return np.random.choice(filtered_array)

%timeit search_cached()
2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

因此幾乎快了1000倍,完全不需要Cython或numba。 但這是一種特殊情況,可能對您沒有用。


如果您想自己進行基准測試,請在此處(基於Jupyter Notebook / Lab,因此為%符號):

%load_ext cython

%%cython

cimport numpy as cnp
import numpy as np

cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
    cdef cnp.ndarray[int] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)

ctypedef fused int_or_float:
    cnp.int32_t
    cnp.int64_t
    cnp.float32_t
    cnp.float64_t

cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
    cdef cnp.ndarray[int_or_float] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)

cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
    cdef int_or_float element
    cdef Py_ssize_t num_valid = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            num_valid += 1

    cdef Py_ssize_t random_index = np.random.randint(0, num_valid)

    cdef Py_ssize_t clamped_index = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            if clamped_index == random_index:
                return element
            clamped_index += 1

import numexpr
import numba as nb
import numpy as np

def search_numexpr(arr, l, u):
    return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])

@nb.njit
def search_numba(arr, lower, upper):
    num_valids = 0
    for item in arr:
        if item >= lower and item <= upper:
            num_valids += 1

    random_index = np.random.randint(0, num_valids)

    valid_index = 0
    for item in arr:
        if item >= lower and item <= upper:
            if valid_index == random_index:
                return item
            valid_index += 1

from simple_benchmark import benchmark, MultiArgument

arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]

b = benchmark(funcs, arguments, argument_name='array size')

%matplotlib widget

import matplotlib.pyplot as plt

plt.style.use('ggplot')
b.plot()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM