繁体   English   中英

在数组中找到下一个最大值的更快方法

[英]Faster way to find the next greatest value in array

我正在尝试改善我的 dataframe 代码的运行时间。 我的思路错了吗?

我有以下代码来查找第 1 列中高于value且索引高于它的第一个值( index_value=n

new_index=(df[n:,1] > value).argmax()

我的问题是:argmax() 参数将构建一个包含 True 和 False 的完整列表,然后,它才会找到第一次出现并返回我预期的索引。

有没有办法改进这段代码? 即在找到第一个True后停止构建列表。

本来不打算发帖的。 我期待 numba 在所有条件下都能获胜,但这并不是注定的。 对提议的解决方案进行了一些基准测试,结果有些有趣,因此在此处发布。 我将使用数组数据来保持简单。

# Proposed solutions
import numpy as np
from numba import njit

# @piRSquared's soln
@njit
def find_first_gt(a, n, value):
    while a[n] <= value:
        n += 1
    return n

# @Ehsan's soln
def numpy_argmax(a, n , value):
    return np.argmax(a[n:] > value)

使用benchit package(几个基准测试工具打包在一起;免责声明:我是它的作者)对建议的解决方案进行基准测试。

计时和加速 -

# Benchmark
a = np.arange(1000_000)
n = 0

import benchit
funcs = [find_first_gt, numpy_argmax]
vs = np.linspace(0, len(a)-1, num=20, endpoint=True).astype(int)
inputs = [(a,0,v) for v in vs]
t = benchit.timings(funcs, inputs, multivar=True, input_name='Position of value')
t.plot(logy=False, logx=False, savepath='plot.png')
t.speedups(ref_func_by_index=1).plot('Speedup_with_numba.png')

在此处输入图像描述

在此处输入图像描述

如果您对确切的加速数字感兴趣 -

In [12]: t.speedups(ref_func_by_index=1)
Out[12]: 
Functions          find_first_gt  Ref:numpy_argmax
Position of value                                 
0                    2103.548010               1.0
52631                  22.053699               1.0
105263                 11.109615               1.0
157894                  7.541725               1.0
210526                  5.640514               1.0
263157                  4.407300               1.0
315789                  3.642989               1.0
368420                  3.028726               1.0
421052                  2.543713               1.0
473683                  2.201336               1.0
526315                  1.931540               1.0
578946                  1.692138               1.0
631578                  1.536912               1.0
684209                  1.455065               1.0
736841                  1.357728               1.0
789472                  1.248716               1.0
842104                  1.176199               1.0
894735                  1.062174               1.0
947367                  1.043791               1.0
999999                  0.983419               1.0

结论:在几乎所有情况下, numba都做得很好,除非您知道该value位于最远端,或者 numba 缓存方案对您不利。

假设您将数据帧转换为 numpy:

使用np.argmax(df[n:,1] > value) 它在第一个值处停止。 当与搜索数组的大小相比,第一次出现更接近 n 时,它比(df[n:,1] > value).argmax() 然而,随着第一次出现越来越接近数组的末尾,两种方法都必须通过数组的大部分来 go 。

要按索引号将列转换为 numpy 数组:

np.argmax(df.iloc[:, 1].to_numpy()[n:] > value)

更新:比较时间:
np.arange(1,000,000)中查找元素999,998

np.argmax(df[n:,1] > value)    time = 0.0008049319999998694
(df[n:,1] > value).argmax()    time = 0.0013422100000000103
Using numba while loop         time = 0.14520884199999995

编辑:请查看@piRSquared 的时间比较答案 numpy v. numba 的表现似乎在该答案中具有可比性。 我不确定为什么它在两种设置下有所不同。

IIUC,并借用@Divakar 的建议

from numba import njit

@njit
def find_first_gt(a, n, value):
    while a[n] <= value:
        n += 1
    return n

find_first_gt(df[1].to_numpy(), n, value)

在一个简单的测试下,我们发现 while 循环的速度大约是 numpy 的两倍。

a = np.arange(1_000_000)
n = 0
value = 999_998

%timeit np.argmax(a > value)
%timeit find_first_gt(a, n, value)

322 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
620 µs ± 66.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

但是,此测试明确测试索引何时为倒数第二个 position。 平均而言,该指数将处于中间位置。 因此,让我们测试数组中的所有值。

def test_numpy(a, n):
    for value in a[::1000]:
        np.argmax(a > value)

def test_find_first(a, n):
    for value in a[::1000]:
        find_first_gt(a, n, value)

a = np.arange(1_000_000)
n = 0

%timeit test_numpy(a, n)
%timeit test_find_first(a, n)

300 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
276 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

其中我们发现平均结果大致相同。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM