
What is the fastest way to select the smallest n elements from an array?

I was having fun writing a quickselect algorithm using numba and wanted to share the results.

Consider the array x:

np.random.seed([3,1415])
x = np.random.permutation(np.arange(10))
x

array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])

What is the quickest way to pull out the smallest n elements?

I've tried the following:

np.partition

np.partition(x, 5)[:5]

array([0, 1, 2, 3, 4])

pd.Series.nsmallest

pd.Series(x).nsmallest(5).values

array([0, 1, 2, 3, 4])
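Note that np.partition only guarantees that the first n entries are the n smallest; their order is unspecified (here they just happen to come out sorted). If sorted output is needed, it is enough to sort that short slice. A minimal sketch; the helper name smallest_n_sorted is hypothetical:

```python
import numpy as np

def smallest_n_sorted(a, n):
    """Return the n smallest values of `a` in ascending order (hypothetical helper)."""
    # kth=n-1 puts the n-th smallest value at index n-1 with everything smaller before it;
    # using n-1 instead of n also works when n == len(a).
    part = np.partition(a, n - 1)[:n]
    # Only the n-element slice is sorted, not the whole array.
    return np.sort(part)

x = np.random.permutation(np.arange(10))
print(smallest_n_sorted(x, 5))  # [0 1 2 3 4]
```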

Update
@user2357112 pointed out in the comments that my function was modifying its input in place. It turns out that's where my performance boost was coming from. So in the end, a crude quickselect implementation in numba performs very similarly to np.partition. Still nothing to sneeze at, but not what I was hoping for.
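The in-place effect is easy to reproduce with NumPy alone: ndarray.partition rearranges the array itself, while np.partition returns a partitioned copy. If a benchmark calls an in-place routine repeatedly on the same array, every call after the first sees input that is already partitioned, which can distort the timings. A small illustration:

```python
import numpy as np

a = np.array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])
original = a.copy()

b = np.partition(a, 5)              # returns a new, partitioned array
assert np.array_equal(a, original)  # a itself is untouched

a.partition(5)                      # partitions a in place (returns None)
print(np.array_equal(a, original))  # False: a has been rearranged
print(np.sort(a[:5]))               # [0 1 2 3 4]: the 5 smallest now sit at the front
```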


As I said in the question, I was messing around with numba and wanted to share what I've found.

Note that I've imported njit and not jit. njit is shorthand for jit(nopython=True): rather than falling back on native Python objects when it can't compile something, it raises an error. So whenever it does its speed-up thing, it only uses things it can actually speed up. This in turn means that my function fails a lot while I figure out what is and isn't allowed.

So far, my opinion is that writing things with numba's jit and njit is finicky and difficult, but kind of worth it when you get to see a decent performance payoff.

This is my quick and dirty quickselect function:

import numpy as np
from numba import njit
import pandas as pd

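@njit
def rselect(a, k):
    # Returns the k smallest elements of a, in no particular order.
    # Note: a is rearranged IN PLACE, so pass a copy if the input must be preserved.
    n = len(a)
    if n <= 1:
        return a
    elif k > n:
        return a
    else:
        # Pick a random pivot and move it to the front.
        p = np.random.randint(n)
        pivot = a[p]
        a[0], a[p] = a[p], a[0]
        # Move everything smaller than the pivot into a[1:i].
        i = j = 1
        while j < n:
            if a[j] < pivot:
                a[j], a[i] = a[i], a[j]
                i += 1
            j += 1
        # Put the pivot at index i-1, so that a[:i-1] < pivot <= a[i:].
        a[i-1], a[0] = a[0], a[i-1]
        if i - 1 <= k <= i:
            return a[:k]
        elif k > i:
            # All of a[:i] belong to the answer; find the rest on the right.
            return np.concatenate((a[:i], rselect(a[i:], k - i)))
        else:
            # The answer lies entirely within a[:i-1].
            return rselect(a[:i-1], k)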
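For sanity-checking the algorithm without numba, here is my own pure-Python/NumPy transcription of the same idea (random pivot, same partition step as rselect), rewritten as a loop over a shrinking window instead of recursion, so nothing needs to be concatenated. It is a sketch, not the code that was timed; the name select_smallest and the rng parameter are hypothetical, and it copies its input so the caller's array is left alone:

```python
import numpy as np

def select_smallest(a, k, rng=None):
    """Return the k smallest values of `a` (unordered), quickselect-style (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    a = np.array(a, copy=True)  # never touch the caller's array
    n = len(a)
    if k >= n:
        return a
    lo, hi = 0, n  # active window a[lo:hi]; a[:lo] is already among the k smallest
    while True:
        # Random pivot from the window, moved to the window's front.
        p = lo + int(rng.integers(hi - lo))
        pivot = a[p]
        a[lo], a[p] = a[p], a[lo]
        # Gather elements smaller than the pivot into a[lo+1:i].
        i = lo + 1
        for j in range(lo + 1, hi):
            if a[j] < pivot:
                a[i], a[j] = a[j], a[i]
                i += 1
        # Pivot goes to index i-1: a[:i-1] < pivot <= a[i:hi].
        a[lo], a[i - 1] = a[i - 1], a[lo]
        if i - 1 <= k <= i:
            return a[:k]
        elif k > i:
            lo = i      # keep a[:i], continue on the right part
        else:
            hi = i - 1  # the answer lies in a[lo:i-1]

x = np.random.permutation(np.arange(10))
print(np.sort(select_smallest(x, 5)))  # [0 1 2 3 4]
```

Because each pass strictly shrinks the window, the loop always terminates, and there is no recursion depth to worry about.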

You'll notice that it returns the same elements as the methods in the question (though not in the same order):

rselect(x, 5)

array([2, 1, 0, 3, 4])
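Because the result comes back in arbitrary order (and rselect rearranges its input, so the check should use an untouched copy of the data), a fair way to test "same elements" is to compare sorted arrays against a sorted reference. A small hypothetical helper:

```python
import numpy as np

def is_n_smallest(result, data, n):
    """True if `result` holds exactly the n smallest values of `data`, in any order."""
    expected = np.sort(np.asarray(data))[:n]
    return len(result) == n and np.array_equal(np.sort(result), expected)

x = np.array([9, 4, 5, 1, 7, 6, 8, 3, 2, 0])
print(is_n_smallest(np.array([2, 1, 0, 3, 4]), x, 5))  # True
print(is_n_smallest(np.partition(x, 5)[:5], x, 5))     # True
```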

What about speed?

def nsmall_np(x, n):
    return np.partition(x, n)[:n]

def nsmall_pd(x, n):
    return pd.Series(x).nsmallest(n).values

def nsmall_pir(x, n):
    # rselect works in place, so give it a copy
    return rselect(x.copy(), n)


from timeit import timeit


results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method')
)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        # DataFrame.set_value was removed in pandas 1.0; .at is the replacement
        results.at[i, j] = timeit(stmt, setp, number=1000)

results

Method   nsmall_np  nsmall_pd  nsmall_pir
Size                                     
100       0.003873   0.336693    0.002941
1000      0.007683   1.170193    0.011460
3000      0.016083   0.309765    0.029628
6000      0.050026   0.346420    0.059591
10000     0.106036   0.435710    0.092076
100000    1.064301   2.073206    0.936986
1000000  11.864195  27.447762   12.755983

results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6))

Figure: Selection Speed (plot of the timings above; image: https://i.stack.imgur.com/hKo2o.png)
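As an alternative to building statement strings and importing from __main__, timeit also accepts a callable, which removes the dependency on how the script is run. A sketch of the same measurement loop written that way; the name time_methods, the fixed seed, and returning a plain dict instead of a DataFrame are my own choices. Because the same x is reused for every call, in-place functions still need to copy their input, as nsmall_pir does:

```python
import timeit
import numpy as np

def nsmall_np(x, n):
    return np.partition(x, n)[:n]

def time_methods(funcs, sizes, number=100, seed=0):
    """Return {size: {function name: total seconds for `number` calls}}."""
    rng = np.random.default_rng(seed)
    results = {}
    for size in sizes:
        x = rng.random(size)
        n = size // 2
        results[size] = {}
        for f in funcs:
            # timeit calls the zero-argument lambda `number` times and returns the total time
            results[size][f.__name__] = timeit.timeit(lambda: f(x, n), number=number)
    return results

print(time_methods([nsmall_np], [100, 1000], number=10))
```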

In general I wouldn't recommend trying to beat NumPy. It's rare that one can even compete (for long arrays), and rarer still to find a faster implementation. And even if it is faster, it's probably not more than twice as fast, so it's seldom worth the effort.

However, I recently tried to do something like this myself, so I can actually share some interesting results.

I didn't come up with this myself. I based my approach on numba's (re-)implementation of np.median; they probably knew what they were doing.

What I ended up with was:

import numba as nb
import numpy as np

@nb.njit
def _partition(A, low, high):
    """copied from numba source code"""
    mid = (low + high) >> 1
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        if A[mid] < A[low]:
            A[low], A[mid] = A[mid], A[low]
    pivot = A[mid]

    A[high], A[mid] = A[mid], A[high]

    i = low
    for j in range(low, high):
        if A[j] <= pivot:
            A[i], A[j] = A[j], A[i]
            i += 1

    A[i], A[high] = A[high], A[i]
    return i

@nb.njit
def _select_lowest(arry, k, low, high):
    """copied from numba source code, slightly changed"""
    i = _partition(arry, low, high)
    while i != k:
        if i < k:
            low = i + 1
            i = _partition(arry, low, high)
        else:
            high = i - 1
            i = _partition(arry, low, high)
    return arry[:k]

@nb.njit
def _nlowest_inner(temp_arry, n, idx):
    """copied from numba source code, slightly changed"""
    low = 0
    high = n - 1
    return _select_lowest(temp_arry, idx, low, high)

@nb.njit
def nlowest(a, idx):
    """copied from numba source code, slightly changed"""
    temp_arry = a.flatten()  # does a copy! :)
    n = temp_arry.shape[0]
    return _nlowest_inner(temp_arry, n, idx)
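The first few swaps in _partition are a median-of-three pivot choice: afterwards A[low] <= A[mid] <= A[high], so the pivot is the median of the first, middle and last element. On already sorted input this picks the middle element instead of an extreme value, which keeps the partition from degenerating. Since the function only uses indexing and comparisons, it can be run as plain Python on a list to see this. The following is a transcription of _partition for illustration; the name partition_py is hypothetical:

```python
def partition_py(A, low, high):
    """Plain-Python copy of the _partition function above (illustration only)."""
    mid = (low + high) >> 1
    # Median-of-three: order A[low], A[mid], A[high] so the median lands at mid.
    if A[mid] < A[low]:
        A[low], A[mid] = A[mid], A[low]
    if A[high] < A[mid]:
        A[high], A[mid] = A[mid], A[high]
        if A[mid] < A[low]:
            A[low], A[mid] = A[mid], A[low]
    pivot = A[mid]
    # Park the pivot at the end, then move everything <= pivot to the left.
    A[high], A[mid] = A[mid], A[high]
    i = low
    for j in range(low, high):
        if A[j] <= pivot:
            A[i], A[j] = A[j], A[i]
            i += 1
    # Pivot to its final slot: A[low:i] <= pivot < A[i+1:high+1].
    A[i], A[high] = A[high], A[i]
    return i

data = list(range(9))            # already sorted input
print(partition_py(data, 0, 8))  # 4: the pivot is the middle element
```

_select_lowest then just repeats this on the side that contains index k until the pivot lands exactly at k.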

I also included some warm-up calls before doing the timings, so that the compilation time isn't included in them:

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)
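The warm-up matters because an @njit function is compiled on its first call for each new combination of argument types, so timing that first call would mostly measure compilation. The warm-up input should therefore have the same types as the benchmark input (here 1-D float64 arrays from np.random.rand). The same idea as a small generic helper, sketched without numba so it runs anywhere; timed_after_warmup is a hypothetical name:

```python
import timeit

def timed_after_warmup(func, *args, number=100):
    """Call func once (e.g. to trigger JIT compilation), then time `number` further calls."""
    func(*args)  # warm-up call: not included in the measurement
    return timeit.timeit(lambda: func(*args), number=number)

print(timed_after_warmup(sorted, [3, 1, 2], number=1000))
```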

Because my computer is (much) slower, I changed the array sizes and the number of repetitions a bit. The results suggest that I (well, the numba developers) have beaten NumPy for small arrays and at the largest size, while being roughly on par between 10000 and 100000 elements:

results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)

rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)

for i in results.index:
    x = np.random.rand(i)
    n = i // 2
    for j in results.columns:
        stmt = '{}(x, n)'.format(j)
        setp = 'from __main__ import {}, x, n'.format(j)
        results.at[i, j] = timeit(stmt, setp, number=100)  # .at replaces the removed set_value

print(results)

Method   nsmall_np nsmall_pd  nsmall_pir      nlowest
Size                                                 
100     0.00343059  0.561372  0.00190855  0.000935566
500     0.00428461   1.79398  0.00326862   0.00187225
1000    0.00560669   3.36844  0.00432595   0.00364284
5000     0.0132515  0.305471   0.0142569    0.0108995
10000    0.0255161  0.340215    0.024847    0.0248285
50000     0.105937  0.543337    0.150277     0.118294
100000      0.2452  0.835571    0.333697     0.248473
500000     1.75214   3.50201     2.20235      1.44085
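To make the comparison easier to read, the nlowest column can be divided by the nsmall_np column. The numbers below are copied from the table above; a ratio below 1 means nlowest was faster than np.partition at that size:

```python
import numpy as np

# Timings copied from the table above (seconds for 100 calls).
sizes = np.array([100, 500, 1000, 5000, 10000, 50000, 100000, 500000])
nsmall_np = np.array([0.00343059, 0.00428461, 0.00560669, 0.0132515,
                      0.0255161, 0.105937, 0.2452, 1.75214])
nlowest = np.array([0.000935566, 0.00187225, 0.00364284, 0.0108995,
                    0.0248285, 0.118294, 0.248473, 1.44085])

ratio = nlowest / nsmall_np  # < 1: nlowest faster, > 1: np.partition faster
for size, r in zip(sizes, ratio):
    print(f"{size:>7}  {r:.2f}")
```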

