I was having fun writing a quick select algorithm using numba
and wanted to share the results.
Consider the array x
np.random.seed([3,1415])
x = np.random.permutation(np.arange(10))
x
array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])
What is the quickest way to pull the smallest n elements?
I've tried
np.partition
np.partition(x, 5)[:5]
array([0, 1, 2, 3, 4])
pd.Series.nsmallest
pd.Series(x).nsmallest(5).values
array([0, 1, 2, 3, 4])
Update
@user2357112 pointed out in the comments that my function was manipulating its input in place. It turns out that's where my performance boost was coming from. So in the end, a crude implementation of quickselect with numba performs about the same as np.partition. Still nothing to sneeze at, but not what I was hoping for.
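To illustrate the in-place issue, here is a minimal pure-Python sketch. The names `partition_in_place` and `data` are hypothetical and only serve the illustration. A function that rearranges its argument leaves the data partly ordered for the next call, so repeated timing runs stop measuring the same work unless every call gets a fresh copy.

```python
def partition_in_place(a):
    """Lomuto-style partition around a[0]; rearranges the list `a` itself.

    Returns the final index of the pivot.
    """
    pivot = a[0]
    i = 1
    for j in range(1, len(a)):
        if a[j] < pivot:
            a[i], a[j] = a[j], a[i]
            i += 1
    a[0], a[i - 1] = a[i - 1], a[0]
    return i - 1


data = [9, 4, 5, 1, 7, 6, 8, 3, 2, 0]
original = list(data)

# Passing a copy leaves the caller's data untouched.
idx_copy = partition_in_place(list(data))
unchanged = data == original

# Passing the list itself reorders it, so a second call sees different input.
idx_inplace = partition_in_place(data)
mutated = data != original

print(unchanged, mutated)
```

A fair benchmark therefore either copies inside every timed call or accepts that only the first call sees the original ordering.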
As I said in the question, I was messing around with numba and wanted to share what I've found.
Note that I've imported njit and not jit. njit is shorthand for jit(nopython=True): it refuses to fall back on native Python objects, so when it does its speed-up thing it only uses things it can actually speed up. This in turn means that my function fails a lot while I figure out what is allowed and what isn't.
So far, my opinion is that writing things with numba's jit and njit is finicky and difficult, but kind of worth it when you get to see a decent performance payoff.
This is my quick and dirty quickselect function:
import numpy as np
from numba import njit
import pandas as pd
import numexpr as ne

@njit
def rselect(a, k):
    n = len(a)
    if n <= 1:
        return a
    elif k > n:
        return a
    else:
        # pick a random pivot and move it to the front
        p = np.random.randint(n)
        pivot = a[p]
        a[0], a[p] = a[p], a[0]
        # partition: everything smaller than the pivot ends up in a[1:i]
        i = j = 1
        while j < n:
            if a[j] < pivot:
                a[j], a[i] = a[i], a[j]
                i += 1
            j += 1
        # put the pivot at its final position, index i - 1
        a[i-1], a[0] = a[0], a[i-1]
        if i - 1 <= k <= i:
            return a[:k]
        elif k > i:
            return np.concatenate((a[:i], rselect(a[i:], k - i)))
        else:
            return rselect(a[:i-1], k)
You'll notice that it returns the same elements as the methods in the question.
rselect(x, 5)
array([2, 1, 0, 3, 4])
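As a sanity check on that claim, the following is a plain-Python sketch of the same random-pivot idea. It is not the numba function above; `smallest_k` is a hypothetical name, and it works on a private copy of a list. It shows that only membership in the result is guaranteed, not the order, which is why the comparison goes through `sorted`.

```python
import random


def smallest_k(values, k):
    """Return the k smallest items of `values` in arbitrary order.

    Iterative quickselect on a private copy, so the input is not modified.
    """
    a = list(values)
    k = max(0, min(k, len(a)))
    lo, hi = 0, len(a) - 1
    while lo < hi:
        # Move a random pivot to the end of the active range.
        p = random.randint(lo, hi)
        a[p], a[hi] = a[hi], a[p]
        pivot = a[hi]
        store = lo
        for j in range(lo, hi):
            if a[j] < pivot:
                a[store], a[j] = a[j], a[store]
                store += 1
        a[store], a[hi] = a[hi], a[store]  # pivot is now at its sorted position
        if store == k or store == k - 1:
            break  # a[:k] already holds the k smallest items
        if store < k:
            lo = store + 1
        else:
            hi = store - 1
    return a[:k]


x = [9, 4, 5, 1, 7, 6, 8, 3, 2, 0]
picked = smallest_k(x, 5)
print(sorted(picked))  # [0, 1, 2, 3, 4]
```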
What about speed?
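def nsmall_np(x, n):
    return np.partition(x, n)[:n]

def nsmall_pd(x, n):
    # pass n through and return the result, so all three methods do comparable work
    return pd.Series(x).nsmallest(n).values

def nsmall_pir(x, n):
    # copy first, because rselect rearranges its input in place
    return rselect(x.copy(), n)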
from timeit import timeit

results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method')
)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        # .at replaces DataFrame.set_value, which was removed in pandas 1.0
        results.at[i, j] = timeit(stmt, setp, number=1000)
results
Method   nsmall_np  nsmall_pd  nsmall_pir
Size
100       0.003873   0.336693    0.002941
1000      0.007683   1.170193    0.011460
3000      0.016083   0.309765    0.029628
6000      0.050026   0.346420    0.059591
10000     0.106036   0.435710    0.092076
100000    1.064301   2.073206    0.936986
1000000  11.864195  27.447762   12.755983
results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6))
In general I wouldn't recommend trying to beat NumPy. It's rare that one can compete (for long arrays), and rarer still to find a faster implementation. Even when it is faster, it's probably not more than 2 times faster, so it's seldom worth it.
However, I recently tried to do something like this myself, so I can actually share some interesting results.
I didn't think this up myself. I based my approach on numba's (re-)implementation of np.median. They probably knew what they were doing.
What I ended up with was:
import numba as nb
import numpy as np

@nb.njit
def _partition(A, low, high):
    """copied from numba source code"""
    mid = (low + high) >> 1
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        if A[mid] < A[low]:
            A[low], A[mid] = A[mid], A[low]
    pivot = A[mid]

    A[high], A[mid] = A[mid], A[high]

    i = low
    for j in range(low, high):
        if A[j] <= pivot:
            A[i], A[j] = A[j], A[i]
            i += 1

    A[i], A[high] = A[high], A[i]
    return i

@nb.njit
def _select_lowest(arry, k, low, high):
    """copied from numba source code, slightly changed"""
    i = _partition(arry, low, high)
    while i != k:
        if i < k:
            low = i + 1
            i = _partition(arry, low, high)
        else:
            high = i - 1
            i = _partition(arry, low, high)
    return arry[:k]

@nb.njit
def _nlowest_inner(temp_arry, n, idx):
    """copied from numba source code, slightly changed"""
    low = 0
    high = n - 1
    return _select_lowest(temp_arry, idx, low, high)

@nb.njit
def nlowest(a, idx):
    """copied from numba source code, slightly changed"""
    temp_arry = a.flatten()  # does a copy! :)
    n = temp_arry.shape[0]
    return _nlowest_inner(temp_arry, n, idx)
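The partition step picks its pivot as the median of the first, middle and last elements of the range, which protects against the bad splits a fixed first-element pivot produces on already sorted input. Below is a plain-Python sketch of that step for experimenting without numba. `partition_median3` and `check_partition` are hypothetical names, and the logic is a reading of the function above, not numba's actual source.

```python
def partition_median3(a, low, high):
    """Partition a[low:high + 1] in place and return the pivot's final index."""
    mid = (low + high) >> 1
    # Order a[low], a[mid], a[high] so that the median ends up at mid.
    if a[mid] < a[low]:
        a[low], a[mid] = a[mid], a[low]
    if a[high] < a[mid]:
        a[high], a[mid] = a[mid], a[high]
        if a[mid] < a[low]:
            a[low], a[mid] = a[mid], a[low]
    pivot = a[mid]
    a[high], a[mid] = a[mid], a[high]  # park the pivot at the end
    i = low
    for j in range(low, high):
        if a[j] <= pivot:
            a[i], a[j] = a[j], a[i]
            i += 1
    a[i], a[high] = a[high], a[i]  # move the pivot to its final position
    return i


def check_partition(a, low, high, i):
    """True if everything left of i is <= a[i] and everything right is >= a[i]."""
    left_ok = all(v <= a[i] for v in a[low:i])
    right_ok = all(v >= a[i] for v in a[i + 1:high + 1])
    return left_ok and right_ok


data = [9, 4, 5, 1, 7, 6, 8, 3, 2, 0]
i = partition_median3(data, 0, len(data) - 1)
print(i, data[i], check_partition(data, 0, len(data) - 1, i))
```

Once the pivot index equals k, everything in front of it is no larger than the pivot, which is why the selection loop can stop there and return the first k elements.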
I also included some warm-up calls before doing the timings, so that numba's compilation time isn't included in the measurements:
rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)
Having a (much) slower computer, I changed the number of elements and the number of repetitions a bit. But the results seem to indicate that I (well, really the numba developers) have beaten NumPy:
results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)

# warm-up calls, so compilation is not timed
rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        # .at replaces DataFrame.set_value, which was removed in pandas 1.0
        results.at[i, j] = timeit(stmt, setp, number=100)

print(results)
Method   nsmall_np  nsmall_pd  nsmall_pir      nlowest
Size
100     0.00343059   0.561372  0.00190855  0.000935566
500     0.00428461    1.79398  0.00326862   0.00187225
1000    0.00560669    3.36844  0.00432595   0.00364284
5000     0.0132515   0.305471   0.0142569    0.0108995
10000    0.0255161   0.340215    0.024847    0.0248285
50000     0.105937   0.543337    0.150277     0.118294
100000      0.2452   0.835571    0.333697     0.248473
500000     1.75214    3.50201     2.20235      1.44085