繁体   English   中英

使用条件(numpy.where)时更快的 numpy 数组索引?

[英]Faster numpy array indexing when using condition (numpy.where)?

我有一个巨大的 numpy 数组,形状为 (50000000, 3),我正在使用:

x = array[np.where((array[:,0] == value) | (array[:,1] == value))]

得到我想要的数组部分。 但是这种方式似乎很慢。 有没有更有效的方法来使用 numpy 执行相同的任务?

np.where是高度优化的,我怀疑有人能写出比上一个 Numpy 版本中实现的代码更快的代码(免责声明:我是优化它的人)。 也就是说,这里的主要问题不是np.where而是创建临时 boolean 数组的条件。 不幸的是,这是在 Numpy 中执行此操作的方法,只要您仅使用具有相同输入布局的 Numpy,就没什么可做的。

解释为什么它不是很有效的原因之一是输入数据布局效率低下 实际上,假设array使用默认的行主要顺序连续存储在 memory 中, array[:,0] == value将读取 memory 中数组的每 3 个项目中的 1 个项目。由于 CPU 缓存的工作方式(即缓存行) ,预取等),浪费了 memory 带宽的 2/3 事实上,output boolean 数组也需要写入,并且由于页面错误,填充新创建的数组有点慢。 请注意, array[:,1] == value肯定会由于输入的大小(无法容纳在大多数 CPU 缓存中)而从 RAM 中重新加载数据 RAM 很慢,与 CPU 和缓存的计算速度相比,它越来越慢。 这个问题称为“ memory 墙”,几十年前就已经出现,预计不会很快得到修复。 另请注意,逻辑或还将创建一个从 RAM 读/写到 RAM 的新数组。 更好的数据布局是在 memory 中连续的(3, 50000000)转置数组(注意np.transpose不会产生连续数组)。

解释性能问题的另一个原因是Numpy 往往未针对非常小的轴进行优化

一个主要的解决方案是在可能的情况下以转置的方式创建输入。 另一种解决方案是编写Numba 或 Cython 代码 这是非转置输入的实现:

# Compilation for the most frequent types. 
# Please pick the right ones so to speed up the compilation time. 
@nb.njit(['(uint8[:,::1],uint8)', '(int32[:,::1],int32)', '(int64[:,::1],int64)', '(float64[:,::1],float64)'], parallel=True)
def select(array, value):
    n = array.shape[0]
    mask = np.empty(n, dtype=np.bool_)
    for i in nb.prange(n):
        mask[i] = array[i, 0] == value or array[i, 1] == value
    return mask

x = array[select(array, value)]

请注意,我使用了并行实现,因为or运算符对于 Numba 不是最优的(唯一的解决方案似乎是使用本机代码或 Cython),而且因为在某些平台(如计算服务器)上,RAM 不能完全被一个线程饱和。 另请注意,对于 select 的结果,使用array[np.where(select(array, value))[0]]select 事实上,如果结果是随机的或非常小,那么np.where可以更快,因为它对 boolean 索引不执行的这些情况进行了特殊优化。 请注意, np.where在 Numba function 的上下文中没有特别优化,因为 Numba 使用它自己的 Numpy 函数实现,并且它们有时没有针对大型 arrays 进行优化。更快的实现包括并行创建x但这不是使用 Numba 很简单,因为提前不知道 output 项目的数量,并且线程必须知道在哪里写入数据,更不用说 Numpy 已经相当快地按顺序执行此操作,只要 output 是可预测的。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM