[英]Fastest way to find all unique pairs of (nearly) parallel 3d vectors from N vectors in Numpy
[英]Fastest way to find all pairs of close numbers in a Numpy array
假設我有一個N = 10
隨機浮點數的 Numpy 數組:
import numpy as np
np.random.seed(99)
N = 10
arr = np.random.uniform(0., 10., size=(N,))
print(arr)
out[1]: [6.72278559 4.88078399 8.25495174 0.31446388 8.08049963
5.6561742 2.97622499 0.46695721 9.90627399 0.06825733]
我想找到所有唯一的數字對,它們彼此之間的差異不超過公差tol = 1.
(即絕對差異 <= 1)。 具體來說,我想獲取所有唯一的索引對。 每個close pair的索引都應該排序,所有close pair都應該按第一個索引排序。 我設法編寫了以下工作代碼:
def all_close_pairs(arr, tol=1.):
res = set()
for i, x1 in enumerate(arr):
for j, x2 in enumerate(arr):
if i == j:
continue
if np.isclose(x1, x2, rtol=0., atol=tol):
res.add(tuple(sorted([i, j])))
res = np.array(list(res))
return res[res[:,0].argsort()]
print(all_close_pairs(arr, tol=1.))
out[2]: [[1 5]
[2 4]
[3 7]
[3 9]
[7 9]]
然而,實際上我有一個N = 1000
數字的數組,並且由於嵌套的 for 循環,我的代碼變得非常慢。 我相信使用 Numpy 矢量化有更有效的方法。 有誰知道在 Numpy 中最快的方法?
問題是您的代碼具有 O(n*n) (二次)復雜度。 為了降低復雜性,您可以嘗試先對項目進行排序:
def all_close_pairs(arr, tol=1.):
res = set()
arr = sorted(enumerate(arr), key=lambda x: x[1])
for (idx1, (i, x1)) in enumerate(arr):
for idx2 in range(idx1-1, -1, -1):
j, x2 = arr[idx2]
if not np.isclose(x1, x2, rtol=0., atol=tol):
break
indices = sorted([i, j])
res.add(tuple(indices))
return np.array(sorted(res))
但是,這僅在您的值范圍遠大於公差時才有效。
您可以通過使用2 pointers
策略進一步改進這一點,但總體復雜性將保持不變。
這是一個純 numpy 操作的解決方案。 在我的機器上看起來相當快,但我不知道我們正在尋找什么樣的速度。
def all_close_pairs(arr, tol=1.):
N = arr.shape[0]
# get indices in the array to consider using meshgrid
pair_coords = np.array(np.meshgrid(np.arange(N), np.arange(N))).T
# filter out pairs so we get indices in increasing order
pair_coords = pair_coords[pair_coords[:, :, 0] < pair_coords[:, :, 1]]
# compare indices in your array for closeness
is_close = np.isclose(arr[pair_coords[:, 0]], arr[pair_coords[:, 1]], rtol=0, atol=tol)
return pair_coords[is_close, :]
一種有效的解決方案是首先使用index = np.argsort()
對輸入值進行排序。 然后,您可以使用arr[index]
生成排序數組,然后如果快速連續數組上的對數較少,則在准線性時間內迭代接近的值。 如果對的數量很大,那么由於生成的對數是二次的,所以復雜度是二次的。 得到的復雜度是: O(n log n + m)
其中n
是輸入數組的大小, m
是產生的對數。
要找到彼此接近的值,一種有效的方法是使用Numba迭代值。 事實上,雖然在 Numpy 中可能是可能的,但由於要比較的值的數量可變,它可能效率不高。 這是一個實現:
import numba as nb
@nb.njit('int32[:,::1](float64[::1], float64)')
def findCloseValues(arr, tol):
res = []
for i in range(arr.size):
val = arr[i]
# Iterate over the close numbers (only once)
for j in range(i+1, arr.size):
# Sadly neither np.isclose or np.abs are implemented in Numba so far
if max(val, arr[j]) - min(val, arr[j]) >= tol:
break
res.append((i, j))
if len(res) == 0: # No pairs: we need to help Numpy to know the shape
return np.empty((0, 2), dtype=np.int32)
return np.array(res, dtype=np.int32)
最后,需要更新索引以引用未排序數組中的索引而不是排序數組。 您可以使用index[result]
來做到這一點。
這是生成的代碼:
index = arr.argsort()
result = findCloseValues(arr[index], 1.0)
print(index[result])
這是結果(順序與問題中的不同,但您可以根據需要對其進行排序):
array([[9, 3],
[9, 7],
[3, 7],
[1, 5],
[4, 2]])
如果您需要更快的算法,那么您可以使用另一種 output 格式:您可以為每個輸入值提供接近目標輸入值的最小值/最大值范圍。 要查找范圍,您可以在已排序的數組上使用二進制搜索(請參閱: np.searchsorted
)。 生成的算法在O(n log n)
中運行。 但是,您無法獲取未排序數組中的索引,因為該范圍是不連續的。
您可以首先使用 itertools.combinations 創建組合:
def all_close_pairs(arr, tolerance):
pairs = list(combinations(arr, 2))
indexes = list(combinations(range(len(arr)), 2))
all_close_pairs_indexes = [indexes[i] for i,pair in enumerate(pairs) if abs(pair[0] - pair[1]) <= tolerance]
return all_close_pairs_indexes
現在,對於 N=1000,您將只需要比較 499500 對而不是 100 萬對!
它比原始代碼快約 900 倍。 對於 N=1000,在我的機器上大約需要 0.15 秒。
有點晚了,但所有 numpy 解決方案:
import numpy as np
def close_enough( arr, tol = 1 ):
result = np.where( np.triu(np.isclose( arr[ :, None ], arr[ None, : ], r_tol = 0.0, atol = tol ), 1))
return np.swapaxes( result, 0, 1 )
展開以解釋正在發生的事情
def close_enough( arr, tol = 1 ):
bool_arr = np.isclose( arr[ :, None ], arr[ None, : ], rtol = 0.0, atol = tol )
# is_close generates a square array after comparing all elements with all elements.
bool_arr = np.triu( bool_arr, 1 )
# Keep the upper right triangle, offset by 1 column. i.e. zero the main diagonal
# and all elements below and to the left.
result = np.where( bool_arr ) # Return the row and column indices for Trues
return np.swapaxes( result, 0, 1 ) # Return the pairs in rows rather than columns
N = 1000,arr = 浮點數數組
%timeit close_enough( arr, tol = 1 )
14.1 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [19]: %timeit all_close_pairs( arr, tol = 1 )
54.3 ms ± 268 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
(close_enough( arr, tol = 1) == all_close_pairs( arr, tol = 1 )).all()
# True
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.