簡體   English   中英

在 Numpy 數組中查找所有接近數對的最快方法

[英]Fastest way to find all pairs of close numbers in a Numpy array

假設我有一個N = 10隨機浮點數的 Numpy 數組:

import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)

out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963 
         5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]

我想找到所有唯一的數字對,它們彼此之間的差異不超過公差tol = 1. (即絕對差異 <= 1)。 具體來說,我想獲取所有唯一的索引對。 每個close pair的索引都應該排序,所有close pair都應該按第一個索引排序。 我設法編寫了以下工作代碼:

def all_close_pairs(arr, tol=1.):
    res = set()
    for i, x1 in enumerate(arr):
        for j, x2 in enumerate(arr):
            if i == j:
                continue
            if np.isclose(x1, x2, rtol=0., atol=tol):
                res.add(tuple(sorted([i, j])))
    res = np.array(list(res))
    return res[res[:,0].argsort()]

print(all_close_pairs(arr, tol=1.))

out[2]: [[1 5]
         [2 4]
         [3 7]
         [3 9]
         [7 9]]

然而,實際上我有一個N = 1000數字的數組,並且由於嵌套的 for 循環,我的代碼變得非常慢。 我相信使用 Numpy 矢量化有更有效的方法。 有誰知道在 Numpy 中最快的方法?

問題是您的代碼具有 O(n*n) (二次)復雜度。 為了降低復雜性,您可以嘗試先對項目進行排序:

def all_close_pairs(arr, tol=1.):
    res = set()
    arr = sorted(enumerate(arr), key=lambda x: x[1])
    for (idx1, (i, x1)) in enumerate(arr):
        for idx2 in range(idx1-1, -1, -1):
            j, x2 = arr[idx2]
            if not np.isclose(x1, x2, rtol=0., atol=tol):
                break
            indices = sorted([i, j])
            res.add(tuple(indices))
    return np.array(sorted(res))

但是,這僅在您的值范圍遠大於公差時才有效。

您可以通過使用2 pointers策略進一步改進這一點,但總體復雜性將保持不變。

這是一個純 numpy 操作的解決方案。 在我的機器上看起來相當快,但我不知道我們正在尋找什么樣的速度。

def all_close_pairs(arr, tol=1.):
    N = arr.shape[0]
    # get indices in the array to consider using meshgrid
    pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
    # filter out pairs so we get indices in increasing order
    pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
    # compare indices in your array for closeness
    is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
    return pair_coords[is_close, :]

一種有效的解決方案是首先使用index = np.argsort()對輸入值進行排序 然后,您可以使用arr[index]生成排序數組,然后如果快速連續數組上的對數較少,則在准線性時間內迭代接近的值。 如果對的數量很大,那么由於生成的對數是二次的,所以復雜度是二次的。 得到的復雜度是: O(n log n + m)其中n是輸入數組的大小, m是產生的對數。

要找到彼此接近的值,一種有效的方法是使用Numba迭代值。 事實上,雖然在 Numpy 中可能是可能的,但由於要比較的值的數量可變,它可能效率不高。 這是一個實現:

import numba as nb

@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
    res = []
    for i in range(arr.size):
        val = arr[i]
        # Iterate over the close numbers (only once)
        for j in range(i+1, arr.size):
            # Sadly neither np.isclose or np.abs are implemented in Numba so far
            if max(val, arr[j]) - min(val, arr[j]) >= tol:
                break
            res.append((i, j))
    if len(res) == 0: # No pairs: we need to help Numpy to know the shape
        return np.empty((0, 2), dtype=np.int32)
    return np.array(res, dtype=np.int32)

最后,需要更新索引以引用未排序數組中的索引而不是排序數組。 您可以使用index[result]來做到這一點。

這是生成的代碼:

index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])

這是結果(順序與問題中的不同,但您可以根據需要對其進行排序):

array([[9, 3],
       [9, 7],
       [3, 7],
       [1, 5],
       [4, 2]])

提高算法的復雜度

如果您需要更快的算法,那么您可以使用另一種 output 格式:您可以為每個輸入值提供接近目標輸入值的最小值/最大值范圍。 要查找范圍,您可以在已排序的數組上使用二進制搜索(請參閱: np.searchsorted )。 生成的算法在O(n log n)中運行。 但是,您無法獲取未排序數組中的索引,因為該范圍是不連續的。

您可以首先使用 itertools.combinations 創建組合:

def all_close_pairs(arr, tolerance):
    pairs = list(combinations(arr, 2))
    indexes = list(combinations(range(len(arr)), 2))
    all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <=  tolerance]
    return all_close_pairs_indexes

現在,對於 N=1000,您將只需要比較 499500 對而不是 100 萬對!

它比原始代碼快約 900 倍。 對於 N=1000,在我的機器上大約需要 0.15 秒。

有點晚了,但所有 numpy 解決方案:

import numpy as np

def close_enough( arr, tol = 1 ): 
    result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], r_tol = 0.0, atol = tol ), 1)) 
    return np.swapaxes( result, 0, 1 ) 

展開以解釋正在發生的事情

def close_enough( arr, tol = 1 ):
    bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
    # is_close generates a square array after comparing all elements with all elements.  

    bool_arr = np.triu( bool_arr, 1 ) 
    # Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal 
    # and all elements below and to the left.

    result = np.where( bool_arr )  # Return the row and column indices for Trues
    return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns 

N = 1000,arr = 浮點數數組

%timeit close_enough( arr, tol = 1 )                                                                              
14.1 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [19]: %timeit all_close_pairs( arr, tol = 1 )                                                                           
54.3 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()                                            
# True

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM