[英]What is the fastest way to filter a numpy array by another array?
我有一个相当大的 np.array a
(10,000-50,000 个元素,每个坐标(x,y))和另一个更大的 np.array b
(100,000-200,000 坐标)。 我需要尽快删除b
中不存在的a
元素,只留下b
中存在的a
元素。 所有坐标都是整数。 例如:
a = np.array([[2,5],[6,3],[4,2],[1,4]])
b = np.array([[2,7],[4,2],[1,5],[6,3]])
所需的 output:
a
>> [6,3],[4,2]
对于我提到的尺寸的 arrays,最快的方法是什么?
除了 Numpy 中的解决方案之外,我也可以接受使用任何其他包或导入的解决方案(例如,使用Pandas等转换为基本 Python 列表或设置)。
这似乎很大程度上取决于数组大小和“稀疏性”(可能是由于 hash 表魔术)。
Get intersecting rows across two 2D numpy arrays的答案是so_8317022
function。
外卖似乎是(在我的机器上):
from collections import defaultdict
import numpy as np
import pandas as pd
import timeit
import matplotlib.pyplot as plt
def pandas_merge(a, b):
return pd.DataFrame(a).merge(pd.DataFrame(b)).to_numpy()
def set_intersection(a, b):
return set(map(tuple, a.tolist())) & set(map(tuple, b.tolist()))
def so_8317022(a, b):
nrows, ncols = a.shape
dtype = {
"names": ["f{}".format(i) for i in range(ncols)],
"formats": ncols * [a.dtype],
}
C = np.intersect1d(a.view(dtype), b.view(dtype))
return C.view(a.dtype).reshape(-1, ncols)
def test_fn(f, a, b):
number, time_taken = timeit.Timer(lambda: f(a, b)).autorange()
return number / time_taken
def test(size, max_coord):
a = np.random.default_rng().integers(0, max_coord, size=(size, 2))
b = np.random.default_rng().integers(0, max_coord, size=(size, 2))
return {fn.__name__: test_fn(fn, a, b) for fn in (pandas_merge, set_intersection, so_8317022)}
series = []
datas = defaultdict(list)
for size in (100, 1000, 10000, 100000):
for max_coord in (50, 500, 5000):
print(size, max_coord)
series.append((size, max_coord))
for fn, result in test(size, max_coord).items():
datas[fn].append(result)
print("size", "sparseness", "func", "ops/sec")
for fn, values in datas.items():
for (size, max_coord), value in zip(series, values):
print(size, max_coord, fn, int(value))
我机器上的结果是
尺寸 | 稀疏性 | 功能 | 操作/秒 |
---|---|---|---|
100 | 50 | pandas_merge | 895 |
100 | 500 | pandas_merge | 777 |
100 | 5000 | pandas_merge | 708 |
1000 | 50 | pandas_merge | 740 |
1000 | 500 | pandas_merge | 751 |
1000 | 5000 | pandas_merge | 660 |
10000 | 50 | pandas_merge | 513 |
10000 | 500 | pandas_merge | 460 |
10000 | 5000 | pandas_merge | 436 |
100000 | 50 | pandas_merge | 11 |
100000 | 500 | pandas_merge | 61 |
100000 | 5000 | pandas_merge | 49 |
100 | 50 | set_intersection | 42281 |
100 | 500 | set_intersection | 44050 |
100 | 5000 | set_intersection | 43584 |
1000 | 50 | set_intersection | 3693 |
1000 | 500 | set_intersection | 3234 |
1000 | 5000 | set_intersection | 3900 |
10000 | 50 | set_intersection | 453 |
10000 | 500 | set_intersection | 287 |
10000 | 5000 | set_intersection | 300 |
100000 | 50 | set_intersection | 47 |
100000 | 500 | set_intersection | 13 |
100000 | 5000 | set_intersection | 13 |
100 | 50 | so_8317022 | 8927 |
100 | 500 | so_8317022 | 9736 |
100 | 5000 | so_8317022 | 7843 |
1000 | 50 | so_8317022 | 698 |
1000 | 500 | so_8317022 | 746 |
1000 | 5000 | so_8317022 | 765 |
10000 | 50 | so_8317022 | 89 |
10000 | 500 | so_8317022 | 48 |
10000 | 5000 | so_8317022 | 57 |
100000 | 50 | so_8317022 | 10 |
100000 | 500 | so_8317022 | 3 |
100000 | 5000 | so_8317022 | 3 |
不确定这是否是最快的方法,但如果你把它变成 pandas 索引,你可以使用它的交集方法。 由于它在底层使用低级 c 代码,因此交叉步骤可能非常快,但将其转换为 pandas 索引可能需要一些时间
import numpy as np
import pandas as pd
a = np.array([[2, 5], [6, 3], [4, 2], [1, 4]])
b = np.array([[2, 7], [4, 2], [1, 5], [6, 3]])
df_a = pd.DataFrame(a).set_index([0, 1])
df_b = pd.DataFrame(b).set_index([0, 1])
intersection = df_a.index.intersection(df_b.index)
结果看起来像这样
print(intersection.values)
[(6, 3) (4, 2)]
编辑2:
出于好奇,我对这些方法进行了比较。 现在有一个更大的索引列表。 我将我的第一个索引方法与一个稍微改进的方法进行了比较,该方法不需要先创建 dataframe,而是立即创建索引,然后再使用 dataframe 合并方法。
这是代码
from random import randint, seed
import time
import numpy as np
import pandas as pd
seed(0)
n_tuple = 100000
i_min = 0
i_max = 10
a = [[randint(i_min, i_max), randint(i_min, i_max)] for _ in range(n_tuple)]
b = [[randint(i_min, i_max), randint(i_min, i_max)] for _ in range(n_tuple)]
np_a = np.array(a)
np_b = np.array(b)
def method0(a_array, b_array):
index_a = pd.DataFrame(a_array).set_index([0, 1]).index
index_b = pd.DataFrame(b_array).set_index([0, 1]).index
return index_a.intersection(index_b).to_numpy()
def method1(a_array, b_array):
index_a = pd.MultiIndex.from_arrays(a_array.T)
index_b = pd.MultiIndex.from_arrays(b_array.T)
return index_a.intersection(index_b).to_numpy()
def method2(a_array, b_array):
df_a = pd.DataFrame(a_array)
df_b = pd.DataFrame(b_array)
return df_a.merge(df_b).to_numpy()
def method3(a_array, b_array):
set_a = {(_[0], _[1]) for _ in a_array}
set_b = {(_[0], _[1]) for _ in b_array}
return set_a.intersection(set_b)
for cnt, intersect in enumerate([method0, method1, method2, method3]):
t0 = time.time()
if cnt < 3:
intersection = intersect(np_a, np_b)
else:
intersection = intersect(a, b)
print(f"method{cnt}: {time.time() - t0}")
output 看起来像:
method0: 0.1439347267150879
method1: 0.14012742042541504
method2: 4.740894317626953
method3: 0.05933070182800293
结论:数据帧的合并方法(method2)比在索引上使用交集要慢约 50 倍。 基于multiindex(method1)的版本只比method0快一点(我的第一个建议)
EDIT2:正如@AKX 的评论所建议的:如果您不使用 numpy 而是使用纯列表和集合,您可以再次获得大约 3 倍的加速。但很明显,您不应该使用合并方法。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.