繁体   English   中英

用另一个阵列过滤 numpy 阵列的最快方法是什么?

[英]What is the fastest way to filter a numpy array by another array?

我有一个相当大的 np.array a (10,000-50,000 个元素,每个坐标(x,y))和另一个更大的 np.array b (100,000-200,000 坐标)。 我需要尽快删除b中不存在的a元素,只留下b中存在的a元素。 所有坐标都是整数。 例如:

a = np.array([[2,5],[6,3],[4,2],[1,4]])
b = np.array([[2,7],[4,2],[1,5],[6,3]])

所需的 output:

a

>> [6,3],[4,2]

对于我提到的尺寸的 arrays,最快的方法是什么?

除了 Numpy 中的解决方案之外,我也可以接受使用任何其他包或导入的解决方案(例如,使用Pandas等转换为基本 Python 列表或设置)。

这似乎很大程度上取决于数组大小和“稀疏性”(可能是由于 hash 表魔术)。

Get intersecting rows across two 2D numpy arrays的答案是so_8317022 function。

外卖似乎是(在我的机器上):

  • Pandas 方法具有大型稀疏集的优势
  • 集合交集非常非常快,数组大小很小(尽管它返回一个集合,而不是 numpy 数组)
  • 其他 Numpy 答案可以比设置更大数组大小的交集更快。
from collections import defaultdict

import numpy as np
import pandas as pd
import timeit
import matplotlib.pyplot as plt


def pandas_merge(a, b):
    return pd.DataFrame(a).merge(pd.DataFrame(b)).to_numpy()


def set_intersection(a, b):
    return set(map(tuple, a.tolist())) & set(map(tuple, b.tolist()))


def so_8317022(a, b):
    nrows, ncols = a.shape
    dtype = {
        "names": ["f{}".format(i) for i in range(ncols)],
        "formats": ncols * [a.dtype],
    }
    C = np.intersect1d(a.view(dtype), b.view(dtype))
    return C.view(a.dtype).reshape(-1, ncols)


def test_fn(f, a, b):
    number, time_taken = timeit.Timer(lambda: f(a, b)).autorange()
    return number / time_taken


def test(size, max_coord):
    a = np.random.default_rng().integers(0, max_coord, size=(size, 2))
    b = np.random.default_rng().integers(0, max_coord, size=(size, 2))
    return {fn.__name__: test_fn(fn, a, b) for fn in (pandas_merge, set_intersection, so_8317022)}


series = []
datas = defaultdict(list)

for size in (100, 1000, 10000, 100000):
    for max_coord in (50, 500, 5000):
        print(size, max_coord)
        series.append((size, max_coord))
        for fn, result in test(size, max_coord).items():
            datas[fn].append(result)

print("size", "sparseness", "func", "ops/sec")
for fn, values in datas.items():
    for (size, max_coord), value in zip(series, values):
        print(size, max_coord, fn, int(value))

我机器上的结果是

尺寸 稀疏性 功能 操作/秒
100 50 pandas_merge 895
100 500 pandas_merge 777
100 5000 pandas_merge 708
1000 50 pandas_merge 740
1000 500 pandas_merge 751
1000 5000 pandas_merge 660
10000 50 pandas_merge 513
10000 500 pandas_merge 460
10000 5000 pandas_merge 436
100000 50 pandas_merge 11
100000 500 pandas_merge 61
100000 5000 pandas_merge 49
100 50 set_intersection 42281
100 500 set_intersection 44050
100 5000 set_intersection 43584
1000 50 set_intersection 3693
1000 500 set_intersection 3234
1000 5000 set_intersection 3900
10000 50 set_intersection 453
10000 500 set_intersection 287
10000 5000 set_intersection 300
100000 50 set_intersection 47
100000 500 set_intersection 13
100000 5000 set_intersection 13
100 50 so_8317022 8927
100 500 so_8317022 9736
100 5000 so_8317022 7843
1000 50 so_8317022 698
1000 500 so_8317022 746
1000 5000 so_8317022 765
10000 50 so_8317022 89
10000 500 so_8317022 48
10000 5000 so_8317022 57
100000 50 so_8317022 10
100000 500 so_8317022 3
100000 5000 so_8317022 3

不确定这是否是最快的方法,但如果你把它变成 pandas 索引,你可以使用它的交集方法。 由于它在底层使用低级 c 代码,因此交叉步骤可能非常快,但将其转换为 pandas 索引可能需要一些时间

import numpy as np
import pandas as pd

a = np.array([[2, 5], [6, 3], [4, 2], [1, 4]])
b = np.array([[2, 7], [4, 2], [1, 5], [6, 3]])

df_a = pd.DataFrame(a).set_index([0, 1])
df_b = pd.DataFrame(b).set_index([0, 1])
intersection = df_a.index.intersection(df_b.index)

结果看起来像这样

print(intersection.values)
[(6, 3) (4, 2)]

编辑2:

出于好奇,我对这些方法进行了比较。 现在有一个更大的索引列表。 我将我的第一个索引方法与一个稍微改进的方法进行了比较,该方法不需要先创建 dataframe,而是立即创建索引,然后再使用 dataframe 合并方法。

这是代码

from random import randint, seed
import time
import numpy as np
import pandas as pd

seed(0)

n_tuple = 100000
i_min = 0
i_max = 10
a = [[randint(i_min, i_max), randint(i_min, i_max)] for _ in range(n_tuple)]
b = [[randint(i_min, i_max), randint(i_min, i_max)] for _ in range(n_tuple)]
np_a = np.array(a)
np_b = np.array(b)


def method0(a_array, b_array):
    index_a = pd.DataFrame(a_array).set_index([0, 1]).index
    index_b = pd.DataFrame(b_array).set_index([0, 1]).index
    return index_a.intersection(index_b).to_numpy()


def method1(a_array, b_array):
    index_a = pd.MultiIndex.from_arrays(a_array.T)
    index_b = pd.MultiIndex.from_arrays(b_array.T)
    return index_a.intersection(index_b).to_numpy()


def method2(a_array, b_array):
    df_a = pd.DataFrame(a_array)
    df_b = pd.DataFrame(b_array)
    return df_a.merge(df_b).to_numpy()


def method3(a_array, b_array):
    set_a = {(_[0], _[1]) for _ in a_array}
    set_b = {(_[0], _[1]) for _ in b_array}
    return set_a.intersection(set_b)


for cnt, intersect in enumerate([method0, method1, method2, method3]):
    t0 = time.time()
    if cnt < 3:
        intersection = intersect(np_a, np_b)
    else:
        intersection = intersect(a, b)
    print(f"method{cnt}: {time.time() - t0}")

output 看起来像:

method0: 0.1439347267150879
method1: 0.14012742042541504
method2: 4.740894317626953
method3: 0.05933070182800293

结论:数据帧的合并方法(method2)比在索引上使用交集要慢约 50 倍。 基于multiindex(method1)的版本只比method0快一点(我的第一个建议)

EDIT2:正如@AKX 的评论所建议的:如果您不使用 numpy 而是使用纯列表和集合,您可以再次获得大约 3 倍的加速。但很明显,您不应该使用合并方法。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM