繁体   English   中英

在列表中查找距离最小的 N 个最大元素

[英]Find N largest elements in a list with a minimum distance

我想从列表中提取 N 个最大的元素,但我想要任何两个元素x[i]x[j]abs(ij) > min_distance

scipy.signal.find_peaks(x, distance=min_distance)提供了这个功能。 但是,我需要重复此操作数百万次,并且我试图加快操作速度。

我注意到find_peaks不接受参数N来指示您要提取的峰值数量。 它也不允许从最大到最小返回峰值,需要额外调用l.sort()l = l[:N]

我试图编写一个惰性排序器,它只查找 N 个最大的元素,而不对列表的其余部分进行排序。

根据此处获得的结果,我选择了heapq 这是我的尝试:

import heapq

def new_find_peaks(x, N, min_distance=0):
    x = enumerate(x)

    x = [(-val,i) for (i,val) in x]
    heapq.heapify(x)

    val, pos = heapq.heappop(x)
    peaks = [(-val, pos,)]

    while len(peaks)<N:

        while True:
            val, pos = heapq.heappop(x)
            d = min([abs(pos - pos_i) for _,pos_i in peaks])
            if d >= min_distance:
                break

        peaks.append((-val, pos,))

    return map(list, zip(*peaks)) #Transpose peaks into 2 lists

然而,这仍然比find_peaks慢 20 倍,可能是由于find_peaks CPython 实现。 另外,我注意到几乎一半的时间都花在了

x = [(-val,i) for (i,val) in x]

你有什么更好的主意来加速这个操作吗?

--- 最小的可重复示例 ---

例如:

x = [-8.11, -7.33, -7.48, -5.77, -8.73, -8.73, -7.02, -7.02,
 -7.80, -10.92, -9.36, -9.83, -10.14, -10.77, -11.23, -9.20,
 -9.52, -9.67, -11.23, -9.98, -7.95, -9.83, -8.89, -7.33,
 -4.20, -4.05, -6.70, -7.02, -9.20, -9.21]

new_find_peaks(x, N=3, min_distance=5)

>> [[-4.05, -5.77, -7.8], [25, 3, 8]]

X

请注意, x[24]为 -4.2,但由于x[25]更大且25-24 < min_distance ,因此将其丢弃。 另请注意, x[8]不是真正的峰值,因为x[7]更大,但由于与x[3]的距离而被丢弃。 这是预期的行为。

在 python 中改进你的代码可能会给你一些改进,但由于你的代码看起来很干净而且算法的想法很合理,我认为你不会用 python 方法击败find_peaks

因此我建议你用一种更接近金属的语言编写你自己的库,如果你需要 python 中的结果,请编写你自己的 python 包装器。 例如,您可以使用 Swift。 是 Swift 中堆队列的一个实现,在这里您可以找到描述的一种与 python 接口的方法。

连接点留作练习。 ;)

为了加快计算速度,我写了第二个不使用heapq的算法。 这避免了在堆队列中重塑整个列表。

新算法看起来像这样

from heapq import nlargest

def find_peaks_fast(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
            last_i = i

        elif val > last_peak:
            last_peak = val
            last_i = i

    return nlargest(N,peaks)

该算法扫描列表一次并提取之前 N 个样本和之后 N 个样本中较高的所有样本。 然后将它们存储在一个列表中,使用 heapq.nlargest 从中仅提取 nlargest 元素

就其本身而言,这将执行时间降低到 3.7 毫秒。 很快,但仍然比 scipy 的 find_peaks 慢近 4 倍。

但是,这可以使用包numba进行更改。 这旨在即时“编译”python 代码并执行编译版本以提高速度。 它改善了很多

from numba import njit
from heapq import nlargest

from numba.errors import NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

@njit
def find_peaks_fast(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
        last_i = i

    elif val > last_peak:
        last_peak = val
        last_i = i

return nlargest(N,peaks)

并测试它

from scipy.signal import find_peaks

from timeit import repeat
from numpy.random import randn

from numba import njit
from heapq import nlargest


def new_find_peaks2(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
            last_i = i

        elif val > last_peak:
            last_peak = val
            last_i = i

    return nlargest(N,peaks)

@njit
def new_find_peaks2_jit(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
            last_i = i

        elif val > last_peak:
            last_peak = val
            last_i = i

    return nlargest(N,peaks)


num = 500
rep = 10

N = 20
x = randn(20000)
sep = 10

code1 = '''
i_pks, _  = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks = pks[:N]
'''

code2 = '''
_ = new_find_peaks2(x, N=N, min_distance=sep)
'''

code2_jit = '''
_ = new_find_peaks2_jit(x, N=N, min_distance=sep)
'''

i_pks, _  = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks1 = pks[:N]
pks2 = new_find_peaks2(x, N=N, min_distance=sep)

print(pks1==pks2)

t = min(repeat(stmt=code1, globals=globals(), number=num, repeat=rep))/num
print(f'np.find_peaks:\t\t{t*1000} [ms]')

t = min(repeat(stmt=code2, globals=globals(), number=num, repeat=rep))/num
print(f'new_find_peaks2:\t{t*1000} [ms]')

t = min(repeat(stmt=code2_jit, globals=globals(), number=num, repeat=rep))/num
print(f'new_find_peaks2_jit:\t{t*1000} [ms]')

导致结果:

np.find_peaks:          1.1234994470141828 [ms]
new_find_peaks2:      3.565517600043677 [ms]
new_find_peaks2_jit:  0.10387242998695001 [ms]

那是一个 x10 的加速!

结论:

  • 该算法可以加速
  • numba.njit被证明是一个令人难以置信的包装器,将相同函数的执行次数增加了近 35 倍!

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM