簡體   English   中英

如何更快地找到兩個多維數組的交集?

[英]How can i find the intersection of two multidimensional arrays faster?

有兩個具有不同行數的多維布爾數組。 我想快速找到公共行中 True 值的索引。 我寫了下面的代碼,但它太慢了。 有沒有更快的方法來做到這一點?

a=np.random.choice(a=[False, True], size=(100,100))
b=np.random.choice(a=[False, True], size=(1000,100))

for i in a:
    for j in b:
        if np.array_equal(i, j):
          print(np.where(i))

讓我們從一個有意義的問題的版本開始,通常會打印一些東西:

a = np.random.choice(a=[False, True], size=(2, 2))
b = np.random.choice(a=[False, True], size=(4, 2))

print(f"a: \n {a}")
print(f"b: \n {b}")

matches = []
for i, x in enumerate(a):
    for j, y in enumerate(b):
        if np.array_equal(x, y):
            matches.append((i, j))

使用scipy.cdist的解決方案將a中的所有行與b中的所有行進行比較,使用漢明距離進行布爾向量比較:

import numpy as np
import scipy
from scipy import spatial

d = scipy.spatial.distance.cdist(a, b, metric='hamming')
cdist_matches = np.where(d == 0)
mathces_values = [(a[i], b[j]) for (i, j) in matches]
cdist_values = a[cdist_matches[0]], b[cdist_matches[1]]
print(f"matches_inds = \n{matches}")
print(f"matches = \n{mathces_values}")

print(f"cdist_inds = \n{cdist_matches}")
print(f"cdist_matches =\n {cdist_values}")

出去:

a: 
 [[ True False]
 [False False]]
b: 
 [[ True  True]
 [ True False]
 [False False]
 [False  True]]
matches_inds = 
[(0, 1), (1, 2)]
matches = 
[(array([ True, False]), array([ True, False])), (array([False, False]), array([False, False]))]
cdist_inds = 
(array([0, 1], dtype=int64), array([1, 2], dtype=int64))
cdist_matches =
 (array([[ True, False],
       [False, False]]), array([[ True, False],
       [False, False]]))


如果您不想import scipy請參閱純 numpy 實現

可以通過使用np.newaxisnp.tile使 a 的形狀可廣播到 b 的形狀來比較 a 的每一行與 b 的每一行

import numpy as np

a=np.random.choice(a=[True, False], size=(2,5))
b=np.random.choice(a=[True, False], size=(10,5))
broadcastable_a = np.tile(a[:, np.newaxis, :], (1, b.shape[0], 1))
a_equal_b = np.equal(b, broadcastable_a)
indexes = np.where(a_equal_b)
indexes = np.stack(np.array(indexes[1:]), axis=1)

如果你想比較 NDarrays 元素,我會做這樣的事情:

import numpy as np

# data
a = np.random.choice(a = [False, True], size = (100,100))
b = np.random.choice(a = [False, True], size = (1000,100))

# extract matching coordinates
match = np.where((a == b[:100,:]) == True)
match = list(zip(*match))

# first 20 coordinates match
print("Number of matches:", len(match))
print(match[:20])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM