簡體   English   中英

使用條件(numpy.where)時更快的 numpy 數組索引?

[英]Faster numpy array indexing when using condition (numpy.where)?

我有一個巨大的 numpy 數組,形狀為 (50000000, 3),我正在使用:

x = array[np.where((array[:,0] == value) | (array[:,1] == value))]

得到我想要的數組部分。 但是這種方式似乎很慢。 有沒有更有效的方法來使用 numpy 執行相同的任務?

np.where是高度優化的,我懷疑有人能寫出比上一個 Numpy 版本中實現的代碼更快的代碼(免責聲明:我是優化它的人)。 也就是說,這里的主要問題不是np.where而是創建臨時 boolean 數組的條件。 不幸的是,這是在 Numpy 中執行此操作的方法,只要您僅使用具有相同輸入布局的 Numpy,就沒什么可做的。

解釋為什么它不是很有效的原因之一是輸入數據布局效率低下 實際上,假設array使用默認的行主要順序連續存儲在 memory 中, array[:,0] == value將讀取 memory 中數組的每 3 個項目中的 1 個項目。由於 CPU 緩存的工作方式(即緩存行) ,預取等),浪費了 memory 帶寬的 2/3 事實上,output boolean 數組也需要寫入,並且由於頁面錯誤,填充新創建的數組有點慢。 請注意, array[:,1] == value肯定會由於輸入的大小(無法容納在大多數 CPU 緩存中)而從 RAM 中重新加載數據 RAM 很慢,與 CPU 和緩存的計算速度相比,它越來越慢。 這個問題稱為“ memory 牆”,幾十年前就已經出現,預計不會很快得到修復。 另請注意,邏輯或還將創建一個從 RAM 讀/寫到 RAM 的新數組。 更好的數據布局是在 memory 中連續的(3, 50000000)轉置數組(注意np.transpose不會產生連續數組)。

解釋性能問題的另一個原因是Numpy 往往未針對非常小的軸進行優化

一個主要的解決方案是在可能的情況下以轉置的方式創建輸入。 另一種解決方案是編寫Numba 或 Cython 代碼 這是非轉置輸入的實現:

# Compilation for the most frequent types. 
# Please pick the right ones so to speed up the compilation time. 
@nb.njit(['(uint8[:,::1],uint8)', '(int32[:,::1],int32)', '(int64[:,::1],int64)', '(float64[:,::1],float64)'], parallel=True)
def select(array, value):
    n = array.shape[0]
    mask = np.empty(n, dtype=np.bool_)
    for i in nb.prange(n):
        mask[i] = array[i, 0] == value or array[i, 1] == value
    return mask

x = array[select(array, value)]

請注意,我使用了並行實現,因為or運算符對於 Numba 不是最優的(唯一的解決方案似乎是使用本機代碼或 Cython),而且因為在某些平台(如計算服務器)上,RAM 不能完全被一個線程飽和。 另請注意,對於 select 的結果,使用array[np.where(select(array, value))[0]]select 事實上,如果結果是隨機的或非常小,那么np.where可以更快,因為它對 boolean 索引不執行的這些情況進行了特殊優化。 請注意, np.where在 Numba function 的上下文中沒有特別優化,因為 Numba 使用它自己的 Numpy 函數實現,並且它們有時沒有針對大型 arrays 進行優化。更快的實現包括並行創建x但這不是使用 Numba 很簡單,因為提前不知道 output 項目的數量,並且線程必須知道在哪里寫入數據,更不用說 Numpy 已經相當快地按順序執行此操作,只要 output 是可預測的。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM