[英]find numpy array in other numpy array
我需要在一个更大的numpy数组中找到一个小的numpy数组。 例如:
import numpy as np
a = np.array([1, 1])
b = np.array([2, 3, 3, 1, 1, 1, 8, 3, 1, 6, 0, 1, 1, 3, 4])
一个功能
find_numpy_array_in_other_numpy_array(a, b)
应该返回索引
[3, 4, 11]
表示完整numpy数组a
出现在完整numpy数组b
。
在处理非常大的b
数组时,这种问题的蛮力方法很慢:
ok = []
for idx in range(b.size - a.size + 1):
if np.all(a == b[idx : idx + a.size]):
ok.append(idx)
我正在寻找一种更快的方法来查找数组b
完整数组a
所有索引。 快速方法还应该允许其他比较函数,例如找出a
和b
之间的最坏情况差异:
diffs = []
for idx in range(b.size - a.size + 1):
bi = b[idx : idx + a.size]
diff = np.nanmax(np.abs(bi - a))
diffs.append(diff)
对于通用解决方案,我们可以创建滑动窗口的2D
阵列,然后执行相关操作 -
from skimage.util.shape import view_as_windows
b2D = view_as_windows(b,len(a))
NumPy equivalent implementation
。
问题#1
然后,为了解决匹配指数问题,它只是 -
matching_indices = np.flatnonzero((b2D==a).all(axis=1))
问题#2
为了解决第二个问题,它可以很容易地映射,记住任何用于获取输出元素的ufunc减少操作将使用该ufunc的axis
参数在建议的解决方案中沿第二轴转换为减少 -
diffs = np.nanmax(np.abs(b2D-a),axis=1)
以下代码查找数组b
序列( a
)中第一个元素的所有匹配项。 然后它创建一个新数组,其中包含可能的候选序列列,将它们与完整序列进行比较,并过滤初始索引
seq, arr = a, b
len_seq = len(seq)
ini_idx = (arr[:-len_seq+1]==seq[0]).nonzero()[0] # idx of possible sequence canditates
seq_candidates = arr[np.arange(1, len_seq)[:, None]+ini_idx] # columns with possible seq. candidates
mask = (seq_candidates==seq[1:,None]).all(axis=0)
idx = ini_idx[mask]
您可以考虑使用Numba来编译该函数。 你可以这样做:
import numpy as np
import numba as nb
@nb.njit(parallel=True)
def search_in_array(a, b):
idx = np.empty(len(b) - len(a) + 1, dtype=np.bool_)
for i in nb.prange(len(idx)):
idx[i] = np.all(a == b[i:i + len(a)])
return np.where(idx)[0]
a = np.array([1, 1])
b = np.array([2, 3, 3, 1, 1, 1, 8, 3, 1, 6, 0, 1, 1, 3, 4])
print(search_in_array(a, b))
# [ 3 4 11]
快速基准:
import numpy as np
np.random.seed(100)
a = np.random.randint(5, size=10)
b = np.random.randint(5, size=10_000_000)
# Non-compiled function
%timeit search_in_array.py_func(a, b)
# 22.8 s ± 242 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Compiled function
%timeit search_in_array(a, b)
# 54.7 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
如您所见,您可以获得大约400倍的加速,并且内存成本相对较低(布尔数组与大数组相同)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.