簡體   English   中英

在列表中查找距離最小的 N 個最大元素

[英]Find N largest elements in a list with a minimum distance

我想從列表中提取 N 個最大的元素,但我想要任何兩個元素x[i]x[j]abs(ij) > min_distance

scipy.signal.find_peaks(x, distance=min_distance)提供了這個功能。 但是,我需要重復此操作數百萬次,並且我試圖加快操作速度。

我注意到find_peaks不接受參數N來指示您要提取的峰值數量。 它也不允許從最大到最小返回峰值,需要額外調用l.sort()l = l[:N]

我試圖編寫一個惰性排序器,它只查找 N 個最大的元素,而不對列表的其余部分進行排序。

根據此處獲得的結果,我選擇了heapq 這是我的嘗試:

import heapq

def new_find_peaks(x, N, min_distance=0):
    x = enumerate(x)

    x = [(-val,i) for (i,val) in x]
    heapq.heapify(x)

    val, pos = heapq.heappop(x)
    peaks = [(-val, pos,)]

    while len(peaks)<N:

        while True:
            val, pos = heapq.heappop(x)
            d = min([abs(pos - pos_i) for _,pos_i in peaks])
            if d >= min_distance:
                break

        peaks.append((-val, pos,))

    return map(list, zip(*peaks)) #Transpose peaks into 2 lists

然而,這仍然比find_peaks慢 20 倍,可能是由於find_peaks CPython 實現。 另外,我注意到幾乎一半的時間都花在了

x = [(-val,i) for (i,val) in x]

你有什么更好的主意來加速這個操作嗎?

--- 最小的可重復示例 ---

例如:

x = [-8.11, -7.33, -7.48, -5.77, -8.73, -8.73, -7.02, -7.02,
 -7.80, -10.92, -9.36, -9.83, -10.14, -10.77, -11.23, -9.20,
 -9.52, -9.67, -11.23, -9.98, -7.95, -9.83, -8.89, -7.33,
 -4.20, -4.05, -6.70, -7.02, -9.20, -9.21]

new_find_peaks(x, N=3, min_distance=5)

>> [[-4.05, -5.77, -7.8], [25, 3, 8]]

X

請注意, x[24]為 -4.2,但由於x[25]更大且25-24 < min_distance ,因此將其丟棄。 另請注意, x[8]不是真正的峰值,因為x[7]更大,但由於與x[3]的距離而被丟棄。 這是預期的行為。

在 python 中改進你的代碼可能會給你一些改進,但由於你的代碼看起來很干凈而且算法的想法很合理,我認為你不會用 python 方法擊敗find_peaks

因此我建議你用一種更接近金屬的語言編寫你自己的庫,如果你需要 python 中的結果,請編寫你自己的 python 包裝器。 例如,您可以使用 Swift。 是 Swift 中堆隊列的一個實現,在這里您可以找到描述的一種與 python 接口的方法。

連接點留作練習。 ;)

為了加快計算速度,我寫了第二個不使用heapq的算法。 這避免了在堆隊列中重塑整個列表。

新算法看起來像這樣

from heapq import nlargest

def find_peaks_fast(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
            last_i = i

        elif val > last_peak:
            last_peak = val
            last_i = i

    return nlargest(N,peaks)

該算法掃描列表一次並提取之前 N 個樣本和之后 N 個樣本中較高的所有樣本。 然后將它們存儲在一個列表中,使用 heapq.nlargest 從中僅提取 nlargest 元素

就其本身而言,這將執行時間降低到 3.7 毫秒。 很快,但仍然比 scipy 的 find_peaks 慢近 4 倍。

但是,這可以使用包numba進行更改。 這旨在即時“編譯”python 代碼並執行編譯版本以提高速度。 它改善了很多

from numba import njit
from heapq import nlargest

from numba.errors import NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

@njit
def find_peaks_fast(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
        last_i = i

    elif val > last_peak:
        last_peak = val
        last_i = i

return nlargest(N,peaks)

並測試它

from scipy.signal import find_peaks

from timeit import repeat
from numpy.random import randn

from numba import njit
from heapq import nlargest


def new_find_peaks2(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
            last_i = i

        elif val > last_peak:
            last_peak = val
            last_i = i

    return nlargest(N,peaks)

@njit
def new_find_peaks2_jit(x, N, min_distance=0):

    peaks = []
    last_i = 0
    last_peak = x[0]

    for i, val in enumerate(x[1:], 1):

        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)

            # Store the new item and move on
            last_peak = val
            last_i = i

        elif val > last_peak:
            last_peak = val
            last_i = i

    return nlargest(N,peaks)


num = 500
rep = 10

N = 20
x = randn(20000)
sep = 10

code1 = '''
i_pks, _  = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks = pks[:N]
'''

code2 = '''
_ = new_find_peaks2(x, N=N, min_distance=sep)
'''

code2_jit = '''
_ = new_find_peaks2_jit(x, N=N, min_distance=sep)
'''

i_pks, _  = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks1 = pks[:N]
pks2 = new_find_peaks2(x, N=N, min_distance=sep)

print(pks1==pks2)

t = min(repeat(stmt=code1, globals=globals(), number=num, repeat=rep))/num
print(f'np.find_peaks:\t\t{t*1000} [ms]')

t = min(repeat(stmt=code2, globals=globals(), number=num, repeat=rep))/num
print(f'new_find_peaks2:\t{t*1000} [ms]')

t = min(repeat(stmt=code2_jit, globals=globals(), number=num, repeat=rep))/num
print(f'new_find_peaks2_jit:\t{t*1000} [ms]')

導致結果:

np.find_peaks:          1.1234994470141828 [ms]
new_find_peaks2:      3.565517600043677 [ms]
new_find_peaks2_jit:  0.10387242998695001 [ms]

那是一個 x10 的加速!

結論:

  • 該算法可以加速
  • numba.njit被證明是一個令人難以置信的包裝器,將相同函數的執行次數增加了近 35 倍!

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM