Find N largest elements in a list with a minimum distance
I want to extract the N largest elements from a list, with the constraint that any two selected elements x[i] and x[j] satisfy abs(i - j) >= min_distance.
scipy.signal.find_peaks(x, distance=min_distance) provides this functionality. However, I need to repeat this operation millions of times, and I am trying to make it faster.
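I noticed that find_peaks does not accept a parameter N to specify how many peaks to extract. It also does not return the peaks ordered from largest to smallest, so an extra descending sort (l.sort(reverse=True)) followed by l = l[:N] is needed.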
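Once peak indices are available (for example the first value returned by find_peaks), the N largest can be selected without sorting every peak. A minimal stdlib sketch using heapq.nlargest keyed on the value; top_n_peaks is a hypothetical helper name. With NumPy arrays, np.argpartition on x[i_pks] followed by sorting only the N selected values is the vectorized counterpart.

```python
from heapq import nlargest


def top_n_peaks(x, peak_idx, N):
    """Return (values, indices) of the N largest peaks, largest first.

    peak_idx is any iterable of candidate indices; nlargest keeps only
    N items at a time instead of sorting all candidates.
    """
    idx = nlargest(N, peak_idx, key=lambda i: x[i])
    return [x[i] for i in idx], idx
```

For example, top_n_peaks([1, 5, 2, 9, 3, 7], [1, 3, 5], 2) picks indices 3 and 5.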
I tried to write a lazy sorter that finds only the N largest elements without sorting the rest of the list:
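```python
import heapq

def new_find_peaks(x, N, min_distance=0):
    # Max-heap emulated with negated values: entries are (-value, index)
    x = enumerate(x)
    x = [(-val,i) for (i,val) in x]
    heapq.heapify(x)
    val, pos = heapq.heappop(x)
    peaks = [(-val, pos)]
    while len(peaks) < N:
        # Pop until a candidate is at least min_distance from every accepted peak
        while True:
            val, pos = heapq.heappop(x)  # raises IndexError if the heap runs out
            d = min(abs(pos - pos_i) for _, pos_i in peaks)
            if d >= min_distance:
                break
        peaks.append((-val, pos))
    return [list(t) for t in zip(*peaks)]  # transpose peaks into [values, indices]
```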
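However, this is still about 20 times slower than find_peaks, probably because find_peaks is implemented in Cython. I also noticed that almost half of the time is spent on the line `x = [(-val,i) for (i,val) in x]`.
Do you have any better ideas for speeding up this operation?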
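One micro-optimization to try for that comprehension: build the same (-value, index) tuples with map, zip and range, which iterate in C rather than in a Python-level loop. heap_entries is a hypothetical name, and any gain is an assumption that should be checked with timeit on real data.

```python
import heapq
from operator import neg


def heap_entries(x):
    # Same list as [(-val, i) for (i, val) in enumerate(x)],
    # built from C-level iterators: negated values zipped with their indices.
    return list(zip(map(neg, x), range(len(x))))


entries = heap_entries([3.0, -1.5, 2.0])
heapq.heapify(entries)  # smallest (-value, index) first, i.e. largest value first
```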
--- Minimal reproducible example ---
For example:
```python
x = [-8.11, -7.33, -7.48, -5.77, -8.73, -8.73, -7.02, -7.02,
     -7.80, -10.92, -9.36, -9.83, -10.14, -10.77, -11.23, -9.20,
     -9.52, -9.67, -11.23, -9.98, -7.95, -9.83, -8.89, -7.33,
     -4.20, -4.05, -6.70, -7.02, -9.20, -9.21]
new_find_peaks(x, N=3, min_distance=5)
# returns [[-4.05, -5.77, -7.8], [25, 3, 8]]
```
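Note that x[24] is -4.2, but it is discarded because x[25] is larger and 25 - 24 < min_distance. Also note that x[8] is not a true peak, because x[7] is larger; however, x[7] is discarded because it is too close to x[3] (7 - 3 < min_distance). This is the intended behavior.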
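The same greedy rule can be sketched without a heap: sort the indices once by value, then walk them from largest to smallest and keep a boolean mask of positions that are too close to an already accepted peak. This replaces the min(...) over all accepted peaks (O(N) per candidate) with an O(1) lookup, at the cost of O(min_distance) mask updates per accepted peak. greedy_peaks is a hypothetical name; unlike new_find_peaks, it returns fewer than N peaks instead of raising when candidates run out.

```python
def greedy_peaks(x, N, min_distance=0):
    """Greedy top-N selection with a minimum index distance.

    Visits indices by decreasing value (ties: lower index first, like the
    (-val, i) heap) and blocks every index closer than min_distance to an
    accepted peak, so each candidate is checked in O(1).
    """
    order = sorted(range(len(x)), key=lambda i: (-x[i], i))
    blocked = [False] * len(x)
    idx = []
    for i in order:
        if blocked[i]:
            continue
        idx.append(i)
        if len(idx) == N:
            break
        # Block every position j with abs(j - i) < min_distance
        for j in range(max(0, i - min_distance + 1), min(len(x), i + min_distance)):
            blocked[j] = True
    return [[x[i] for i in idx], idx]
```

Called on the example list above with N=3 and min_distance=5, it selects indices 25, 3 and 8, the same as new_find_peaks.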
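To speed up the computation, I wrote a second algorithm that does not heapify the list. This avoids restructuring the entire list into a heap queue.
The new algorithm looks like this: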
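```python
from heapq import nlargest

def find_peaks_fast(x, N, min_distance=0):
    peaks = []
    last_i = 0          # index of the current candidate peak
    last_peak = x[0]    # value of the current candidate peak
    for i, val in enumerate(x[1:], 1):
        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)
            # Store the new item and move on
            last_peak = val
            last_i = i
        elif val > last_peak:
            last_peak = val
            last_i = i
    # Keep the candidate still pending at the end of the list
    # (otherwise x[25] would be lost in the example above)
    peaks.append(last_peak)
    return nlargest(N, peaks)
```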
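This algorithm scans the list once. It tracks a running candidate and stores it as a peak once min_distance samples have passed without a larger value, so the stored peaks are at least min_distance apart. heapq.nlargest then extracts only the N largest of the stored peaks.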
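find_peaks_fast returns only the values. If positions are also needed, as in new_find_peaks, a hedged variant can keep (value, index) tuples so that nlargest carries the index along. find_peaks_fast_with_index is a hypothetical name; it also appends the candidate still pending when the loop ends, so a maximum within the last min_distance samples is not dropped.

```python
from heapq import nlargest


def find_peaks_fast_with_index(x, N, min_distance=0):
    # Same single-pass scan, but candidates are (value, index) tuples.
    peaks = []
    last_i = 0
    last_peak = x[0]
    for i in range(1, len(x)):
        val = x[i]
        if i - last_i == min_distance:
            # No larger value within min_distance samples: store the candidate
            peaks.append((last_peak, last_i))
            last_peak = val
            last_i = i
        elif val > last_peak:
            last_peak = val
            last_i = i
    # Store the candidate still pending at the end of the list
    peaks.append((last_peak, last_i))
    return nlargest(N, peaks)
```

On the example list above with N=3 and min_distance=5 it returns [(-4.05, 25), (-5.77, 3), (-7.8, 8)].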
On its own, find_peaks_fast brings the execution time down to 3.7 ms. That is fast, but still roughly 3 to 4 times slower than scipy's find_peaks.
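However, this can be changed with the numba package, which "compiles" Python code just-in-time and runs the compiled version for speed. It improves things a lot!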
```python
import warnings
from heapq import nlargest

from numba import njit
# In older numba releases this lived in numba.errors
from numba.core.errors import NumbaPendingDeprecationWarning

warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

@njit
def find_peaks_fast(x, N, min_distance=0):
    peaks = []
    last_i = 0
    last_peak = x[0]
    for i, val in enumerate(x[1:], 1):
        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)
            # Store the new item and move on
            last_peak = val
            last_i = i
        elif val > last_peak:
            last_peak = val
            last_i = i
    peaks.append(last_peak)  # keep the candidate pending at the end of the list
    return nlargest(N, peaks)
```
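Numba compiles an @njit function on its first call, so that call includes compilation time; a benchmark that reports the minimum over repeated runs hides this. A sketch that times the two separately, assuming NumPy is installed; scan_peaks is a hypothetical name for the same scan, and the try/except fallback to plain Python exists only so the sketch runs without numba. It passes a NumPy array rather than a Python list: as far as I know, NumbaPendingDeprecationWarning is emitted for reflected (Python list) arguments, so an array input may make the warning filter unnecessary.

```python
import time
from heapq import nlargest

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba, fall back to plain Python so the sketch runs.
    def njit(func):
        return func


@njit
def scan_peaks(x, N, min_distance):
    # Same single-pass candidate scan as find_peaks_fast.
    peaks = []
    last_i = 0
    last_peak = x[0]
    for i in range(1, len(x)):
        val = x[i]
        if i - last_i == min_distance:
            peaks.append(last_peak)
            last_peak = val
            last_i = i
        elif val > last_peak:
            last_peak = val
            last_i = i
    peaks.append(last_peak)
    return nlargest(N, peaks)


x = np.random.default_rng(0).standard_normal(20_000)

t0 = time.perf_counter()
first = scan_peaks(x, 20, 10)   # first call: includes JIT compilation
t1 = time.perf_counter()
second = scan_peaks(x, 20, 10)  # later calls: compiled code only
t2 = time.perf_counter()
print(f"first call:  {(t1 - t0) * 1e3:.2f} ms")
print(f"second call: {(t2 - t1) * 1e3:.2f} ms")
```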
And to test it:
```python
from scipy.signal import find_peaks
from timeit import repeat
from numpy.random import randn
from numba import njit
from heapq import nlargest

def new_find_peaks2(x, N, min_distance=0):
    peaks = []
    last_i = 0
    last_peak = x[0]
    for i, val in enumerate(x[1:], 1):
        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)
            # Store the new item and move on
            last_peak = val
            last_i = i
        elif val > last_peak:
            last_peak = val
            last_i = i
    peaks.append(last_peak)  # keep the candidate pending at the end of the list
    return nlargest(N, peaks)

@njit
def new_find_peaks2_jit(x, N, min_distance=0):
    peaks = []
    last_i = 0
    last_peak = x[0]
    for i, val in enumerate(x[1:], 1):
        if i - last_i == min_distance:
            # Store peak
            peaks.append(last_peak)
            # Store the new item and move on
            last_peak = val
            last_i = i
        elif val > last_peak:
            last_peak = val
            last_i = i
    peaks.append(last_peak)  # keep the candidate pending at the end of the list
    return nlargest(N, peaks)

num = 500   # calls per timing run
rep = 10    # timing runs; the minimum is reported
N = 20
x = randn(20000)
sep = 10

code1 = '''
i_pks, _ = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks = pks[:N]
'''
code2 = '''
_ = new_find_peaks2(x, N=N, min_distance=sep)
'''
code2_jit = '''
_ = new_find_peaks2_jit(x, N=N, min_distance=sep)
'''

# Compare the scipy result with the pure-Python result element-wise
i_pks, _ = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()  # in-place descending sort through a reversed view
pks1 = pks[:N]
pks2 = new_find_peaks2(x, N=N, min_distance=sep)
print(pks1 == pks2)

t = min(repeat(stmt=code1, globals=globals(), number=num, repeat=rep)) / num
print(f'np.find_peaks:\t\t{t*1000} [ms]')
t = min(repeat(stmt=code2, globals=globals(), number=num, repeat=rep)) / num
print(f'new_find_peaks2:\t{t*1000} [ms]')
t = min(repeat(stmt=code2_jit, globals=globals(), number=num, repeat=rep)) / num
print(f'new_find_peaks2_jit:\t{t*1000} [ms]')
```
Results:
```
np.find_peaks:          1.1234994470141828 [ms]
new_find_peaks2:        3.565517600043677 [ms]
new_find_peaks2_jit:    0.10387242998695001 [ms]
```
That is a 10x speedup over find_peaks!
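The speed-up factors follow directly from the printed timings; this snippet just recomputes them from the reported numbers (no new measurements).

```python
# Timings reported above, in ms per call.
t_scipy = 1.1234994470141828
t_python = 3.565517600043677
t_jit = 0.10387242998695001

print(f"jit vs scipy find_peaks: {t_scipy / t_jit:.1f}x")
print(f"jit vs plain Python:     {t_python / t_jit:.1f}x")
```

This gives about 10.8x against find_peaks and about 34.3x against the uncompiled function.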
Conclusion:
numba.njit proved to be an incredible wrapper, speeding up the very same function by almost 35 times!