繁体   English   中英

Python在部分更改的数组中查找最大值索引的最有效方法

[英]Python most efficient way to find index of maximum in partially changed array

我有一个大约 750000 个元素的复值数组,我重复(比如 10^6 次或更多次)更新 1000 个(或更少)不同的元素。 在绝对平方数组中,我需要找到最大值的索引。 这是运行大约需要 700 秒的较大代码的一部分。 其中,通常 75%(约 550 秒)用于查找最大值的索引。 尽管ndarray.argmax()根据https://stackoverflow.com/a/26820109/5269892 “非常快”,但在 750000 个元素的数组上重复运行它(即使只更改了 1000 个元素)也只需要很多时间。

下面是一个最小的完整示例,其中我使用了随机数和索引。 您不能假设实值数组'b'在更新后如何变化(即值可能更大、更小或相等),除非必须,该数组位于前一个最大值的索引处( 'b[imax]' )在更新后会变小。

我尝试使用排序数组,其中只有更新的值(按排序顺序)插入到正确的位置以保持排序,因为那时我们知道最大值总是在索引-1处,我们不必重新计算它。 下面的最小示例包括时间。 不幸的是,选择未更新的值并插入更新的值需要太多时间(所有其他步骤组合起来只需要 ~210 musec 而不是ndarray.argmax()的 ~580 musec)。

上下文:这是高效 Clark (1980) 变体中反卷积算法 CLEAN (Hoegbom, 1974) 实现的一部分。 当我打算实现 Sequence CLEAN 算法(Bose+,2002)时,需要更多的迭代,或者可能想要使用更大的输入数组,我的问题是:

问题:在更新的数组中找到最大值索引的最快方法是什么(在每次迭代中不将ndarray.argmax()应用于整个数组)?

最小示例代码(在python 3.7.6, numpy 1.21.2, scipy 1.6.0上运行):

import numpy as np

# some array shapes ('nnu_use' and 'nm'), number of total values ('nvals'), number of selected values ('nsel'; here
# 'nsel' == 'nvals'; in general 'nsel' <= 'nvals') and number of values to be changed ('nchange')
nnu_use, nm = 10418//2 + 1, 144
nvals = nnu_use * nm
nsel = nvals
nchange = 1000

# fix random seed, generate random 2D 'Fourier transform' ('a', complex-valued), compute power ('b', real-valued), and
# two 2D arrays for indices of axes 0 and 1
np.random.seed(100)
a = np.random.rand(nsel) + 1j * np.random.rand(nsel)
b = a.real ** 2 + a.imag ** 2
inu_2d = np.tile(np.arange(nnu_use)[:,None], (1,nm))
im_2d = np.tile(np.arange(nm)[None,:], (nnu_use,1))

# select 'nsel' random indices and get 1D arrays of the selected 2D indices
isel = np.random.choice(nvals, nsel, replace=False)
inu_sel, im_sel = inu_2d.flatten()[isel], im_2d.flatten()[isel]

def do_update_iter(a, b):
    # find index of maximum, choose 'nchange' indices of which 'nchange - 1' are random and the remaining one is the
    # index of the maximum, generate random complex numbers, update 'a' and compute updated 'b'
    imax = b.argmax()
    ichange = np.concatenate(([imax],np.random.choice(nsel, nchange-1, replace=False)))
    a_change = np.random.rand(nchange) + 1j*np.random.rand(nchange)
    a[ichange] = a_change
    b[ichange] = a_change.real ** 2 + a_change.imag ** 2
    return a, b, ichange

# do an update iteration on 'a' and 'b'
a, b, ichange = do_update_iter(a, b)

# sort 'a', 'b', 'inu_sel' and 'im_sel'
i_sort = b.argsort()
a_sort, b_sort, inu_sort, im_sort = a[i_sort], b[i_sort], inu_sel[i_sort], im_sel[i_sort]

# do an update iteration on 'a_sort' and 'b_sort'
a_sort, b_sort, ichange = do_update_iter(a_sort, b_sort)
b_sort_copy = b_sort.copy()

ind = np.arange(nsel)

def binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange):
    # binary insertion as an idea to save computation time relative to repeated argmax over entire (large) arrays
    # find updated values for 'a_sort', compute updated values for 'b_sort'
    a_change = a_sort[ichange]
    b_change = a_change.real ** 2 + a_change.imag ** 2
    # sort the updated values for 'a_sort' and 'b_sort' as well as the corresponding indices
    i_sort = b_change.argsort()
    a_change_sort = a_change[i_sort]
    b_change_sort = b_change[i_sort]
    inu_change_sort = inu_sort[ichange][i_sort]
    im_change_sort = im_sort[ichange][i_sort]
    # find indices of the non-updated values, cut out those indices from 'a_sort', 'b_sort', 'inu_sort' and 'im_sort'
    ind_complement = np.delete(ind, ichange)
    a_complement = a_sort[ind_complement]
    b_complement = b_sort[ind_complement]
    inu_complement = inu_sort[ind_complement]
    im_complement = im_sort[ind_complement]
    # find indices where sorted updated elements would have to be inserted into the sorted non-updated arrays to keep
    # the merged arrays sorted and insert the elements at those indices
    i_insert = b_complement.searchsorted(b_change_sort)
    a_updated = np.insert(a_complement, i_insert, a_change_sort)
    b_updated = np.insert(b_complement, i_insert, b_change_sort)
    inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
    im_updated = np.insert(im_complement, i_insert, im_change_sort)

    return a_updated, b_updated, inu_updated, im_updated

# do the binary insertion
a_updated, b_updated, inu_updated, im_updated = binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)

# do all the steps of the binary insertion, just to have the variable names defined
a_change = a_sort[ichange]
b_change = a_change.real ** 2 + a_change.imag ** 2
i_sort = b_change.argsort()
a_change_sort = a_change[i_sort]
b_change_sort = b_change[i_sort]
inu_change_sort = inu_sort[ichange][i_sort]
im_change_sort = im_sort[ichange][i_sort]
ind_complement = np.delete(ind, i_sort)
a_complement = a_sort[ind_complement]
b_complement = b_sort[ind_complement]
inu_complement = inu_sort[ind_complement]
im_complement = im_sort[ind_complement]
i_insert = b_complement.searchsorted(b_change_sort)
a_updated = np.insert(a_complement, i_insert, a_change_sort)
b_updated = np.insert(b_complement, i_insert, b_change_sort)
inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
im_updated = np.insert(im_complement, i_insert, im_change_sort)

# timings for argmax and for sorting
%timeit b.argmax()             # 579 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit b_sort.argmax()        # 580 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.sort(b)             # 70.2 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.sort(b_sort)        # 25.2 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit b_sort_copy.sort()     # 14 ms ± 78.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# timings for binary insertion
%timeit binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)          # 33.7 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit a_change = a_sort[ichange]                                         # 4.28 µs ± 40.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change = a_change.real ** 2 + a_change.imag ** 2                 # 8.25 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit i_sort = b_change.argsort()                                        # 35.6 µs ± 529 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_change_sort = a_change[i_sort]                                   # 4.2 µs ± 62.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change_sort = b_change[i_sort]                                   # 2.05 µs ± 47 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit inu_change_sort = inu_sort[ichange][i_sort]                        # 4.47 µs ± 38 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit im_change_sort = im_sort[ichange][i_sort]                          # 4.51 µs ± 48.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit ind_complement = np.delete(ind, ichange)                           # 1.38 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit a_complement = a_sort[ind_complement]                              # 3.52 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_complement = b_sort[ind_complement]                              # 1.44 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_complement = inu_sort[ind_complement]                          # 1.36 ms ± 6.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_complement = im_sort[ind_complement]                            # 1.31 ms ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit i_insert = b_complement.searchsorted(b_change_sort)                # 148 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_updated = np.insert(a_complement, i_insert, a_change_sort)       # 3.08 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_updated = np.insert(b_complement, i_insert, b_change_sort)       # 1.37 ms ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_updated = np.insert(inu_complement, i_insert, inu_change_sort) # 1.41 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_updated = np.insert(im_complement, i_insert, im_change_sort)    # 1.52 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

根据https://stackoverflow.com/a/26820109/5269892,ndarray.argmax ndarray.argmax()是“非常快”

Argmax 不是最佳的,因为它无法使我的机器上的 RAM 带宽饱和(这是可能的),但它非常好,因为它在您的情况下使总 RAM 吞吐量的约 40% 和大约 65%-70% 饱和在我的机器上顺序(一个核心不能使大多数机器上的 RAM 饱和)。 大多数机器的吞吐量较低,因此np.argmax应该更接近这些机器上的最优值。

使用多个线程找到最大值有助于达到最佳状态,但就功能的当前性能而言,大多数 PC 上的加速不应超过 2(在计算服务器上更多)。

在更新的数组中找到最大值的索引的最快方法是什么

无论进行何种计算,读取内存中的整个数组至少需要b.size * 8 / RAM_throughput秒。 使用非常好的 2 通道 DDR4 RAM,最佳时间约为 ~125 us,而最好的 1 通道 DDR4 RAM 达到 ~225 us。 如果数组是就地写入的,则最佳时间会大两倍,如果创建一个新数组(异地计算),那么它在 x86-64 平台上会大 3 倍。 事实上,由于操作系统虚拟内存的巨大开销,这对后者来说更糟。

这意味着读取整个数组的任何异地计算都np.argmax主流 PC上击败 np.argmax 。 这也解释了为什么排序解决方案如此缓慢:它创建了许多临时数组 即使是完美的排序数组策略也不会比这里的np.argmax快多少(因为在最坏的情况下,所有项目都需要在 RAM 中移动,并且平均远远超过一半)。 事实上,写入整个数组的任何就地方法的好处都很低(仍然在主流 PC 上):它只会比np.argmax稍微快一点。 获得显着加速的唯一解决方案是不对整个阵列进行操作。

解决此问题的一种有效解决方案是使用平衡二叉搜索树 实际上,您可以在O(k log n)时间内从包含n个节点的树中删除k个节点。 然后,您可以同时插入更新的值。 在您的情况下,这比O(n)解决方案要好得多,因为 n ~= 75000 和 k ~= 1000。不过,请注意,复杂性背后有一个隐藏因素,二叉搜索树在实践中可能不会那么快,尤其是如果它们不是很优化。 另请注意,更新树值比删除节点并插入新节点更好。 在这种情况下,纯 Python 实现肯定没有机会变得很快。 只有 **Cython 或本机解决方案可以快速(例如 C/C++ 或任何本机实现的 Python 模块,但我找不到任何快速的模块)。

另一种选择是基于静态 n 元树的部分最大值数据结构 它包括将数组分成块并首先预先计算每个块的最大值。 更新值时(并假设项目数是恒定的),您需要(1) 重新计算每个 chunk 的最大值 要计算全局最大值,您需要(2) 计算每个块最大值的最大值 该解决方案还需要(半)本机实现,以便快速,因为 Numpy 将在更新每块最大值期间引入大量开销。 例如,可以使用 Numba 和 Cython 来执行此操作。 需要仔细选择块的大小。 在您的情况下,8 到 16 之间的值应该可以显着加快速度。

对于大小为 8 的块,最多只需要读取 8*k=8000 个值来重新计算总最大值(最多写入 1000 个值)。 这远远小于 75000。部分最大值的更新需要计算 an/8 ~= 9400 的最大值,该值仍然相对较小。 我希望这至少快两倍,但肯定快 4 倍。 并行实现应该更快一些。 这当然是最好的解决方案(没有额外的假设)。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM