[英]Find N largest elements in a list with a minimum distance
我想从列表中提取 N 个最大的元素,但我想要任何两个元素x[i]
和x[j]
, abs(ij) > min_distance
。
scipy.signal.find_peaks(x, distance=min_distance)
提供了这个功能。 但是,我需要重复此操作数百万次,并且我试图加快操作速度。
我注意到find_peaks
不接受参数N
来指示您要提取的峰值数量。 它也不允许从最大到最小返回峰值,需要额外调用l.sort()
和l = l[:N]
。
我试图编写一个惰性排序器,它只查找 N 个最大的元素,而不对列表的其余部分进行排序。
import heapq
def new_find_peaks(x, N, min_distance=0):
x = enumerate(x)
x = [(-val,i) for (i,val) in x]
heapq.heapify(x)
val, pos = heapq.heappop(x)
peaks = [(-val, pos,)]
while len(peaks)<N:
while True:
val, pos = heapq.heappop(x)
d = min([abs(pos - pos_i) for _,pos_i in peaks])
if d >= min_distance:
break
peaks.append((-val, pos,))
return map(list, zip(*peaks)) #Transpose peaks into 2 lists
然而,这仍然比find_peaks
慢 20 倍,可能是由于find_peaks
CPython 实现。 另外,我注意到几乎一半的时间都花在了
x = [(-val,i) for (i,val) in x]
你有什么更好的主意来加速这个操作吗?
--- 最小的可重复示例 ---
例如:
x = [-8.11, -7.33, -7.48, -5.77, -8.73, -8.73, -7.02, -7.02,
-7.80, -10.92, -9.36, -9.83, -10.14, -10.77, -11.23, -9.20,
-9.52, -9.67, -11.23, -9.98, -7.95, -9.83, -8.89, -7.33,
-4.20, -4.05, -6.70, -7.02, -9.20, -9.21]
new_find_peaks(x, N=3, min_distance=5)
>> [[-4.05, -5.77, -7.8], [25, 3, 8]]
请注意, x[24]
为 -4.2,但由于x[25]
更大且25-24 < min_distance
,因此将其丢弃。 另请注意, x[8]
不是真正的峰值,因为x[7]
更大,但由于与x[3]
的距离而被丢弃。 这是预期的行为。
为了加快计算速度,我写了第二个不使用heapq的算法。 这避免了在堆队列中重塑整个列表。
新算法看起来像这样
from heapq import nlargest
def find_peaks_fast(x, N, min_distance=0):
peaks = []
last_i = 0
last_peak = x[0]
for i, val in enumerate(x[1:], 1):
if i - last_i == min_distance:
# Store peak
peaks.append(last_peak)
# Store the new item and move on
last_peak = val
last_i = i
elif val > last_peak:
last_peak = val
last_i = i
return nlargest(N,peaks)
该算法扫描列表一次并提取之前 N 个样本和之后 N 个样本中较高的所有样本。 然后将它们存储在一个列表中,使用 heapq.nlargest 从中仅提取 nlargest 元素
就其本身而言,这将执行时间降低到 3.7 毫秒。 很快,但仍然比 scipy 的 find_peaks 慢近 4 倍。
但是,这可以使用包numba
进行更改。 这旨在即时“编译”python 代码并执行编译版本以提高速度。 它改善了很多!
from numba import njit
from heapq import nlargest
from numba.errors import NumbaPendingDeprecationWarning
import warnings
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
@njit
def find_peaks_fast(x, N, min_distance=0):
peaks = []
last_i = 0
last_peak = x[0]
for i, val in enumerate(x[1:], 1):
if i - last_i == min_distance:
# Store peak
peaks.append(last_peak)
# Store the new item and move on
last_peak = val
last_i = i
elif val > last_peak:
last_peak = val
last_i = i
return nlargest(N,peaks)
并测试它
from scipy.signal import find_peaks
from timeit import repeat
from numpy.random import randn
from numba import njit
from heapq import nlargest
def new_find_peaks2(x, N, min_distance=0):
peaks = []
last_i = 0
last_peak = x[0]
for i, val in enumerate(x[1:], 1):
if i - last_i == min_distance:
# Store peak
peaks.append(last_peak)
# Store the new item and move on
last_peak = val
last_i = i
elif val > last_peak:
last_peak = val
last_i = i
return nlargest(N,peaks)
@njit
def new_find_peaks2_jit(x, N, min_distance=0):
peaks = []
last_i = 0
last_peak = x[0]
for i, val in enumerate(x[1:], 1):
if i - last_i == min_distance:
# Store peak
peaks.append(last_peak)
# Store the new item and move on
last_peak = val
last_i = i
elif val > last_peak:
last_peak = val
last_i = i
return nlargest(N,peaks)
num = 500
rep = 10
N = 20
x = randn(20000)
sep = 10
code1 = '''
i_pks, _ = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks = pks[:N]
'''
code2 = '''
_ = new_find_peaks2(x, N=N, min_distance=sep)
'''
code2_jit = '''
_ = new_find_peaks2_jit(x, N=N, min_distance=sep)
'''
i_pks, _ = find_peaks(x, distance=sep)
pks = x[i_pks]
pks[::-1].sort()
pks1 = pks[:N]
pks2 = new_find_peaks2(x, N=N, min_distance=sep)
print(pks1==pks2)
t = min(repeat(stmt=code1, globals=globals(), number=num, repeat=rep))/num
print(f'np.find_peaks:\t\t{t*1000} [ms]')
t = min(repeat(stmt=code2, globals=globals(), number=num, repeat=rep))/num
print(f'new_find_peaks2:\t{t*1000} [ms]')
t = min(repeat(stmt=code2_jit, globals=globals(), number=num, repeat=rep))/num
print(f'new_find_peaks2_jit:\t{t*1000} [ms]')
导致结果:
np.find_peaks: 1.1234994470141828 [ms]
new_find_peaks2: 3.565517600043677 [ms]
new_find_peaks2_jit: 0.10387242998695001 [ms]
那是一个 x10 的加速!
结论:
numba.njit
被证明是一个令人难以置信的包装器,将相同函数的执行次数增加了近 35 倍!
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.