[英]Fastest way to find the nearest pairs between two numpy arrays without duplicates
[英]Fastest way to find all indexes of matching values between two 1D arrays (with duplicates)
问题描述
假设我们有两个简单的 arrays:
query = np.array([100, 4000, 500, 700, 400, 100])
match = np.array([6, 100, 4000, 100, 10, 8, 10])
我想找到查询和匹配之间所有匹配值的索引。 所以在这种情况下,结果将是:
value query match
100 0 1
100 0 3
100 5 1
100 5 3
4000 1 2
实际上,这些 arrays 将包含数百万个项目
“愚蠢”的循环解决方案
qs = []
query_locs = []
match_locs = []
for i in np.arange(query.size):
q = query[i]
# Get matching indexes in "match"
match_loc = np.where(match == q)[0]
n = match_loc.size
# Update location arrays
match_locs.extend(match_loc)
query_locs.extend(np.repeat(i,n))
# Store the matching value
qs.extend(np.repeat(q,n))
result = np.vstack((qs, query_locs, match_locs)).T
print(result)
[[ 100 0 1]
[ 100 0 3]
[4000 1 2]
[ 100 5 1]
[ 100 5 3]]
(也许numba
可以使这个循环非常快,但是当我尝试这个时,我得到了一些关于签名的错误,所以不确定)
Numpy 构建
有相当多的内置 numpy function 来解决这个唯一值的问题,比如使用searchsorted
, intersect1d
,但是,正如文档中所述,它们“返回排序的唯一值”,因此不考虑重复项。 StackOverflow 上针对此问题的一些具有唯一值的示例:
我可以想象用 numpy 而不是循环会有一种更快的方法来做到这一点,所以很想看到答案!
您可以将一维数组转换为数据帧并进行连接,如下所示:
query = np.array([100, 4000, 500, 700, 400, 100])
match = np.array([6, 100, 4000, 100, 10, 8, 10])
dfquery = pd.DataFrame(range(len(query)), index=query, columns=['query'])
dfmatch = pd.DataFrame(range(len(match)), index=match, columns=['match'])
dfquery.join(dfmatch, how='inner')
结果:
query match
100 0 1
100 0 3
100 5 1
100 5 3
4000 1 2
你可以用 newaxis 破解它:
>>> comparison = np.equal(query[:, np.newaxis], match[np.newaxis, :])
array([[False, True, False, True, False, False, False],
[False, False, True, False, False, False, False],
[False, False, False, False, False, False, False],
[False, False, False, False, False, False, False],
[False, False, False, False, False, False, False],
[False, True, False, True, False, False, False]])
它实质上创建了笛卡尔积( query
x matches
)(注意 memory 成本),然后应用二进制 function np.equal
以有效地将产品空间中的每个元素转换为bool
。 output 可以通过逐行读取来解释为:只要 compare comparison[i, j]
为True
,查询元素 i 就等于匹配元素 j 。 您可以使用以下方法收集所有True
对的索引:
list(zip(*comparison.nonzero()))
[(0, 1), (0, 3), (1, 2), (5, 1), (5, 3)]
ps:如果 arrays 太长而无法构建产品,那么逐元素迭代它们是您唯一的选择。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.