簡體   English   中英

基於二進制矩陣中元素包含的快速數組操作

[英]Fast array manipulation based on element inclusion in binary matrix

對於2D晶格中的大量隨機分布點,我想要有效地提取子陣列,該子陣列僅包含近似為索引的元素,這些元素被分配給單獨的2D二進制矩陣中的非零值。 目前,我的腳本如下:

lat_len = 100 # lattice length
input = np.random.random(size=(1000,2)) * lat_len
binary_matrix = np.random.choice(2, lat_len * lat_len).reshape(lat_len, -1)

def landed(input):
    output = []
    input_as_indices = np.floor(input)
    for i in range(len(input)):
        if binary_matrix[input_as_indices[i,0], input_as_indices[i,1]] == 1:
            output.append(input[i])
    output = np.asarray(output)
    return output   

但是,我懷疑必須有更好的方法來做到這一點。 上面的腳本可能需要很長時間才能運行10000次迭代。

你是對的。 上面的計算可以使用高級numpy索引在python中沒有for循環的情況下更有效地完成,

def landed2(input):
    idx = np.floor(input).astype(np.int)
    mask = binary_matrix[idx[:,0], idx[:,1]] == 1
    return input[mask]

res1 = landed(input)
res2 = landed2(input)
np.testing.assert_allclose(res1, res2)

這導致加速~150倍。

如果使用線性索引數組,似乎可以顯着提升性能。 這是一個解決我們案例的矢量化實現,類似於@rth的答案 ,但使用線性索引 -

# Get floor-ed indices
idx = np.floor(input).astype(np.int)

# Calculate linear indices 
lin_idx = idx[:,0]*lat_len + idx[:,1]

# Index raveled/flattened version of binary_matrix with lin_idx
# to extract and form the desired output
out = input[binary_matrix.ravel()[lin_idx] ==1]

因此,總之我們有:

out = input[binary_matrix.ravel()[idx[:,0]*lat_len + idx[:,1]] ==1]

運行時測試 -

本節將此解決方案中提出的方法與使用行列索引的其他解決方案進行比較。

案例#1(原始數據):

In [62]: lat_len = 100 # lattice length
    ...: input = np.random.random(size=(1000,2)) * lat_len
    ...: binary_matrix = np.random.choice(2, lat_len * lat_len).
                                             reshape(lat_len, -1)
    ...: 

In [63]: idx = np.floor(input).astype(np.int)

In [64]: %timeit input[binary_matrix[idx[:,0], idx[:,1]] == 1]
10000 loops, best of 3: 121 µs per loop

In [65]: %timeit input[binary_matrix.ravel()[idx[:,0]*lat_len + idx[:,1]] ==1]
10000 loops, best of 3: 103 µs per loop

案例#2(更大的數據量):

In [75]: lat_len = 1000 # lattice length
    ...: input = np.random.random(size=(100000,2)) * lat_len
    ...: binary_matrix = np.random.choice(2, lat_len * lat_len).
                                             reshape(lat_len, -1)
    ...: 

In [76]: idx = np.floor(input).astype(np.int)

In [77]: %timeit input[binary_matrix[idx[:,0], idx[:,1]] == 1]
100 loops, best of 3: 18.5 ms per loop

In [78]: %timeit input[binary_matrix.ravel()[idx[:,0]*lat_len + idx[:,1]] ==1]
100 loops, best of 3: 13.1 ms per loop

因此,這種線性索引的性能提升似乎約為20% 30%

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM