[英]Fastest way to check which interval index a value is
我有一個像這樣的向量:
intervals = [6, 7, 8, 9, 10, 11] #always regular
我想檢查一個值是哪個間隔索引。 例如: 8.5
所在的間隔的索引是3
。
#Interval : index
6 -> 7 : 1
7 -> 8 : 2
8 -> 9 : 3
9 -> 10 : 4
10 -> 11 : 5
所以我做了這段代碼:
from numpy import *
N = 8000
data = random.random(N)
step_number = 50
max_value = max(data)
min_value = min(data)
step_length = (max_value - min_value)/step_number
intervals = arange(min_value + step_length, max_value + step_length, step_length )
for x in data:
for index in range(len(intervals)):
if x < intervals[index]:
print("That's the index", index)
break
這段代碼可以正常工作,但是速度太慢了,我想我在這些循環中浪費時間。 有辦法更快地檢查嗎? 也許使用一些numpy特殊功能為我檢查一下...
根據您要如何處理端點,有bisect.bisect_left
和bisect.bisect_right
:
>>> import bisect
>>> intervals = [6, 7, 8, 9, 10, 11]
>>> for n in (6, 6.1, 6.2, 6.5, 6.8, 7):
... print bisect.bisect_left(intervals, n)
...
0
1
1
1
1
1
>>> for n in (6, 6.1, 6.2, 6.5, 6.8, 7):
... print bisect.bisect_right(intervals, n)
...
1
1
1
1
1
2
Numpy使用searchsorted
方法實現同一件事。
>>> import numpy as np
>>> np.searchsorted(intervals, (6, 6.1, 6.2, 6.5, 6.8, 7), side='left')
array([0, 1, 1, 1, 1, 1])
>>> np.searchsorted(intervals, (6, 6.1, 6.2, 6.5, 6.8, 7), side='right')
array([1, 1, 1, 1, 1, 2])
而且,當然,如果間隔相等,則可以執行以下操作:
>>> for n in (6, 6.1, 6.2, 6.5, 6.8, 7):
... iwidth = intervals[1] - intervals[0]
... print np.ceil((n - intervals[0]) / iwidth)
...
0.0
1.0
1.0
1.0
1.0
1.0
正如其他人提到的,如果間隔不規則,請使用二等分搜索(例如np.searchsorted
和/或np.digitize
)。
但是,在您已聲明自己總是有固定間隔的特定情況下,還可以執行以下操作:
import numpy as np
intervals = [6, 7, 8, 9, 10, 11]
vals = np.array([8.5, 6.2, 9.8])
dx = intervals[1] - intervals[0]
x0 = intervals[0]
i = np.ceil((vals - x0) / dx).astype(int)
或者,以您的示例代碼為基礎:
import numpy as np
N = 8000
num_intervals = 50
data = np.random.random(N)
intervals = np.linspace(data.min(), data.max(), num_intervals)
x0 = intervals[0]
dx = intervals[1] - intervals[0]
i = np.ceil((data - x0) / dx).astype(int)
這將比對大型數組進行二進制搜索快得多。
只要列表已排序,就可以使用bisect庫獲取插入索引。
index = bisect.bisect_left(intervals, 8.5)
僅使用numpy:
import numpy as np
intervals = np.array([6, 7, 8, 9, 10, 11])
val = (intervals > 8.5)
print val.argmax()
我會去找一個函數:
def f_idx(f_list, number):
for idx,item in enumerate(f_list):
if item>number:
return idx
return len(f_list)
在一個襯里中:
result = [idx for idx,value in enumerate(intervals) if value>number][0] if intervals[-1]>number else len(intervals)
使用numpy.digitize
:
http://docs.scipy.org/doc/numpy-1.10.0/reference/generation/numpy.digitize.html#numpy-digitize
>>> import numpy as np
>>> intervals = [6, 7, 8, 9, 10, 11]
>>> data = [3.5, 6.3, 9.4, 11.5, 8.5]
>>> np.digitize(data, bins=interval)
array([0, 1, 4, 6, 3])
0
是下溢, len(intervals)
是上溢
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.