簡體   English   中英

如何從python中numpy.searchsorted的結果加快數組屏蔽的性能?

[英]How to speed up the performance of array masking from the results of numpy.searchsorted in python?

我想從 numpy.searchsorted() 的結果中生成一個掩碼:

import numpy as np

# generate test examples
x = np.random.rand(1000000)
y = np.random.rand(200)

# sort x
idx = np.argsort(x)
sorted_x = np.take_along_axis(x, idx, axis=-1)

# searchsort y in x
pt = np.searchsorted(sorted_x, y)

pt是一個數組。 然后我想創建一個大小為(200, 1000000)的布爾掩碼,當其索引為idx[0:pt[i]]時,它的值為 True ,我想出了一個像這樣的 for 循環:

mask = np.zeros((200, 1000000), dtype='bool')
for i in range(200):
     mask[i, idx[0:pt[i]]] = True

任何人都有加速for循環的想法?

方法#1

根據從OP's comments中提取的新信息,這些信息表明只有y實時變化,我們可以對x周圍的很多東西進行預處理,因此做得更好。 我們將創建一個散列數組來存儲階梯掩碼。 對於涉及y的部分,我們將簡單地使用從searchsorted獲得的索​​引索引到哈希數組,這將近似於最終的掩碼數組。 鑒於其參差不齊的性質,分配剩余 bool 的最后一步可以卸載到 numba。 如果我們決定擴大y的長度,這也應該是有益的。

我們來看看實現。

使用x預處理:

sidx = x.argsort()
ssidx = x.argsort().argsort()

# Choose a scale factor. 
# 1. A small one would store more mapping info, hence faster but occupy more mem
# 2. A big one would store less mapping info, hence slower, but memory efficient.
scale_factor = 100
mapar = np.arange(0,len(x),scale_factor)[:,None] > ssidx

y剩余步驟:

import numba as nb

@nb.njit(parallel=True,fastmath=True)
def array_masking3(out, starts, idx, sidx):
    N = len(out)
    for i in nb.prange(N):
        for j in nb.prange(starts[i], idx[i]):
            out[i,sidx[j]] = True
    return out

idx = np.searchsorted(x,y,sorter=sidx)
s0 = idx//scale_factor
starts = s0*scale_factor
out = mapar[s0]
out = array_masking3(out, starts, idx, sidx)

基准測試

In [2]: x = np.random.rand(1000000)
   ...: y = np.random.rand(200)

In [3]: ## Pre-processing step with "x"
   ...: sidx = x.argsort()
   ...: ssidx = x.argsort().argsort()
   ...: scale_factor = 100
   ...: mapar = np.arange(0,len(x),scale_factor)[:,None] > ssidx

In [4]: %%timeit
   ...: idx = np.searchsorted(x,y,sorter=sidx)
   ...: s0 = idx//scale_factor
   ...: starts = s0*scale_factor
   ...: out = mapar[s0]
   ...: out = array_masking3(out, starts, idx, sidx)
41 ms ± 141 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

 # A 1/10th smaller hashing array has similar timings
In [7]: scale_factor = 1000
   ...: mapar = np.arange(0,len(x),scale_factor)[:,None] > ssidx

In [8]: %%timeit
   ...: idx = np.searchsorted(x,y,sorter=sidx)
   ...: s0 = idx//scale_factor
   ...: starts = s0*scale_factor
   ...: out = mapar[s0]
   ...: out = array_masking3(out, starts, idx, sidx)
40.6 ms ± 196 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

# @silgon's soln    
In [5]: %timeit x[np.newaxis,:] < y[:,np.newaxis]
138 ms ± 896 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

方法#2

這從OP's solution借用了一個很好的部分。

import numba as nb

@nb.njit(parallel=True)
def array_masking2(mask1D, mask_out, idx, pt):
    n = len(idx)
    for j in nb.prange(len(pt)):
        if mask1D[j]:
            for i in nb.prange(pt[j],n):
                mask_out[j, idx[i]] = False
        else:
            for i in nb.prange(pt[j]):
                mask_out[j, idx[i]] = True
    return mask_out

def app2(idx, pt):
    m,n = len(pt), len(idx)      
    mask1 = pt>len(x)//2
    mask2 = np.broadcast_to(mask1[:,None], (m,n)).copy()
    return array_masking2(mask1, mask2, idx, pt)

所以,這個想法是一次,我們有超過一半的索引要設置為True ,我們在將這些行預先分配為所有True后切換到設置False 這會導致更少的內存訪問,從而顯着提升性能。

基准測試

OP的解決方案:

@nb.njit(parallel=True,fastmath=True)
def array_masking(mask, idx, pt):
    for j in nb.prange(pt.shape[0]):
        for i in nb.prange(pt[j]):
            mask[j, idx[i]] = True
    return mask

def app1(idx, pt):
    m,n = len(pt), len(idx)      
    mask = np.zeros((m, n), dtype='bool')
    return array_masking(mask, idx, pt)

時間——

In [5]: np.random.seed(0)
   ...: x = np.random.rand(1000000)
   ...: y = np.random.rand(200)

In [6]: %timeit app1(idx, pt)
264 ms ± 8.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [7]: %timeit app2(idx, pt)
165 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

這是一個替代答案,但不確定它是否正是您所需要的。

x = np.random.rand(1000000)
y = np.random.rand(200)
mask = x[np.newaxis,:] < y[:,np.newaxis]

注意:我提到這可能不是你需要的,因為你指定了numpy.searchsorted()的需要,在這里我沒有使用它,但是我得到了相同的結果。 如果它不完全符合您的需求,它可能在未來對其他人也有用;)。

時間(@DanielF 編輯)

設置:

import numpy as np

# generate test examples
x = np.random.rand(1000000)
y = np.random.rand(200)

# sort x
idx = np.argsort(x)
sorted_x = np.take_along_axis(x, idx, axis=-1)

跑步:

%%timeit   #  silgon
mask = x[np.newaxis,:] < y[:,np.newaxis]

166 ms ± 3.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


%%timeit    # Divakar
pt = np.searchsorted(sorted_x, y)
mask = app2(idx, pt)

316 ms ± 29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


%%timeit   #  f.c.
pt = np.searchsorted(sorted_x, y)
mask = np.zeros((200, 1000000), dtype='bool')
for i in range(200):
     mask[i, idx[0:pt[i]]] = True
     
466 ms ± 13.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

我已經實現了一個 numba 版本的 for 循環,它比 python 版本稍微快一些。 這是代碼:

import numba as nb
@nb.njit(parallel=True,fastmath=True)
def array_masking(mask, idx, pt):
    for j in nb.prange(pt.shape[0]):
        for i in nb.prange(pt[j]):
            mask[j, idx[i]] = True

我正在尋找進一步的加速。 有任何想法嗎?

為了更清楚,您正在尋找一種快速方法來確定小於y[i]x項的索引掩碼。 例如,如果對x項目進行排序的索引是:

np.argsort(x) = [5, 0, 2, 10, 7, 8, 9, 11, 1, 6, 3, 4]

並且您知道8項目小於y[i] ,您需要從該列表的相反順序中選擇前 8 個項目,然后:

arg_inv = [1, 8, 2, 10, 11, 0, 9, 4, 5, 6, 3, 7]

這個問題最簡單的方法是高級索引:

length_x, length_y = len(x), len(y)
idx = np.argsort(x)
arg_inv = np.argsort(idx)
pt = np.searchsorted(x, y, sorter=idx)
mask = np.zeros((length_y, length_x), dtype='bool')
row, col = np.divmod(np.arange(length_x * length_y), length_x)
mask[row, col] = arg_inv[col] < pt[row]
return mask

我還添加了一個小樣本的例子:

x = [0.809 0.958 0.881 0.146 0.882 0.421 0.604]
y = [0.119 0.981 0.775 0.254]
np.sort(x) = [0.146 0.421 0.604 0.809 0.881 0.882 0.958]
np.argsort(x) = [3 5 6 0 2 4 1]
arg_inv = [3 6 4 0 5 1 2]
pt = [0 7 3 1]
Process of advanced indexing:

    row  col  arg_inv[col]  pt[row]  arg_inv[col] < pt[row]
0     0    0             3        0                       0
1     0    1             6        0                       0
2     0    2             4        0                       0
3     0    3             0        0                       0
4     0    4             5        0                       0
5     0    5             1        0                       0
6     0    6             2        0                       0
7     1    0             3        7                       1
8     1    1             6        7                       1
9     1    2             4        7                       1
10    1    3             0        7                       1
11    1    4             5        7                       1
12    1    5             1        7                       1
13    1    6             2        7                       1
14    2    0             3        3                       0
15    2    1             6        3                       0
16    2    2             4        3                       0
17    2    3             0        3                       1
18    2    4             5        3                       0
19    2    5             1        3                       1
20    2    6             2        3                       1
21    3    0             3        1                       0
22    3    1             6        1                       0
23    3    2             4        1                       0
24    3    3             0        1                       1
25    3    4             5        1                       0
26    3    5             1        1                       0
27    3    6             2        1                       0

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM