簡體   English   中英

查找一個數組中哪些元素與另一個元素中的任何元素接近的最有效方法是什么?

[英]What's the most efficient way to find which elements of one array are close to any element in another?

我有兩個1維numpy.ndarray對象,並想numpy.ndarray第一個數組中的哪些元素在第二個數組中的任何元素的dx內。

我現在擁有的是什么

# setup
numpy.random.seed(1)
a = numpy.random.random(1000)  # create one array
numpy.random.seed(2)
b = numpy.random.random(1000)  # create second array
dx = 1e-4  # close-ness parameter

# function I want to optimise
def find_all_close(a, b):
    # compare one number to all elements of b
    def _is_coincident(t):
        return (numpy.abs(b - t) <= dx).any()
    # vectorize and loop over a
    is_coincident = numpy.vectorize(_is_coincident)
    return is_coincident(a).nonzero()[0]

返回timeit結果如下

10 loops, best of 3: 16.5 msec per loop

優化find_all_close函數的最佳方法是什么,特別是如果ab保證是float ,當它們傳遞給find_all_close時可能會以升序排序,可能是cython或類似的?

在實踐中,我正在使用10,000到100,000個元素(或更大)的數組,並在幾百個不同的b數組上運行整個操作。

最簡單的方法是對第一個數組中的每個元素,對第二個數組進行兩次二進制搜索,以找到最多dx以下的元素,最多在第一個數組中的元素上方dx。 這是線性時間:

left = np.searchsorted(b, a - dx, 'left')
right = np.searchsorted(b, a + dx, 'right')
a[left != right]

線性算法有兩個指向第二個數組的指針,它們在迭代第一個數組中的元素時跟蹤移動窗口。

您的方法是二次的,這是一個用於排序數組的單程線性時間算法。 您只需在正確的時間運行兩個陣列。

def prox(a,b,dx):
    ia=ib=ir=0
    res=zeros(a.size,int32)
    while ia<a.size and ib<b.size:
        if abs(a[ia]-b[ib])<dx: 
            res[ir]=ia
            ir += 1
            ia += 1
        elif a[ia]>b[ib] :
               ib += 1
        else : ia += 1
    return res[:ir]      

您可以使用Numba編譯此代碼以進一步提高性能。

測試:

a=rand(1000)
b=rand(1000)
a.sort()
b.sort()

In [10]:   prox(a,b,1e-5)
Out[10]: 
array([ 35,  90, 159, 165, 174, 252, 276, 380, 383, 467, 508, 515, 641,
       658, 705, 711, 728, 814, 857, 871, 907, 945])

In [11]: %timeit prox(a,b,1e-4)
100 loops, best of 3: 6.23 ms per loop
In [12]: prox2=numba.jit(prox)
In [13]: %timeit prox2(a,b,1e-4)
10000 loops, best of 3: 19.1 µs per loop

這沒有利用數據的排序特性,因此它沒有線性時間復雜度(雖然我懷疑運行時確實從它被排序,緩存方式中受益),但是nlogn並不壞,並且肯定很難被擊敗簡單和充分測試的條款:

from scipy.spatial import cKDTree
print(cKDTree(a[:, None]).query_ball_point(b[:, None], dx))

ab是排序數組時,這種排序的性質可能會被np.searchsorted 濫用 基本思想是我們得到b索引,其左邊可以放置a每個元素,以便維護排序的順序。 這是我們都知道的約束在特定格的右側路b ,每個從元素的a可以放置。 要獲得相同bin的左側邊界,只需從先前找到的那些索引中減去1 然后,獲取每個a與左邊界和右邊界之間的差異,看看是否在閾值范圍內,如果是,則將其視為有效索引。

對於拐角情況將會有一些額外的工作,即當a中的元素大於最高b並且元素小於最低b 如果我們使用np.searchsorted'left'搜索選項,我們只需要將它剪切到最小值1以找到正確的邊界,這樣就可以一次性在整個數組中使用這些相同的索引。 因此,實現看起來像這樣 -

def find_all_close_searchsorted(a, b):
    lidx = np.searchsorted(b,a,'left').clip(min=1,max=b.size-1)
    close_mask = (np.abs(b[lidx] - a) <= dx) | (np.abs(b[lidx-1] - a) <= dx)
    return np.nonzero(close_mask)[0]

運行時測試 -

In [2]: np.random.seed(1)
   ...: a = np.sort(np.random.random(1000))  # create one array
   ...: np.random.seed(2)
   ...: b = np.sort(np.random.random(1000))  # create second array
   ...: dx = 1e-4  # close-ness parameter
   ...: 

In [3]: np.allclose(find_all_close_searchsorted(a, b),find_all_close(a, b))
Out[3]: True

In [4]: %timeit find_all_close(a,b)
100 loops, best of 3: 16 ms per loop

In [5]: %timeit find_all_close_searchsorted(a,b)
10000 loops, best of 3: 91.4 µs per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM