Computing quick convex hull using Numba

I came across this nice NumPy implementation of a 2D convex hull (quickhull). I would like to @njit this function so I can use it inside my other Numba-jitted code. However, I'm not able to modify it so that it runs, since it uses recursion and (I suspect) unsupported Numba features. Can anybody help me rewrite it?

import numpy as np
from numba import njit

def process(S, P, a, b):
    # Proportional to the signed distance of each point of P from the line a -> b
    signed_dist = np.cross(S[P] - S[a], S[b] - S[a])
    # Points strictly on the positive side, excluding the endpoints themselves
    K = [i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b]

    if len(K) == 0:
        return (a, b)

    # The farthest point from the line is a hull vertex: split there and recurse
    c = max(zip(signed_dist, P))[1]
    return process(S, K, a, c)[:-1] + process(S, K, c, b)

def quickhull_2d(S: np.ndarray) -> tuple:
    # The leftmost and rightmost points are always hull vertices
    a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
    return process(S, np.arange(S.shape[0]), a, b)[:-1] + process(S, np.arange(S.shape[0]), b, a)[:-1]

Example data input and output

points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)

print(ch)
[0, 4, 1, 3]

print(points[ch])
[[0. 0.]
 [1. 0.]
 [1. 1.]
 [0. 1.]]

There are several issues that prevent this code from being compiled with Numba.

First of all, returning variable-sized tuples is not possible in Numba, because the type of a tuple implicitly includes its size. A tuple is basically a structured type, not a list. See this post and this one for more information about this issue. The solution is to return either a list (slow) or an array (fast).
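To make the array option concrete, the tuple expression process(...)[:-1] + process(...) can be replaced by a helper that joins two index chains into a freshly allocated array, whose Numba type (a 1D int64 array) does not depend on its length. This is a minimal sketch: join_chains is a hypothetical name, and the njit fallback shim is an assumption added only so the snippet also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # shim (assumption): run as plain Python without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f


@njit
def join_chains(left, right):
    # Array version of `left[:-1] + right` for two index chains:
    # the shared vertex (last of left, first of right) is kept only once.
    n_left = left.shape[0] - 1
    out = np.empty(n_left + right.shape[0], dtype=np.int64)
    out[:n_left] = left[:n_left]
    out[n_left:] = right
    return out
```

For instance, joining [0, 4] and [4, 1] gives [0, 4, 1]. The left chain is assumed to be non-empty, which always holds for the chains built by process.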

Moreover, the types of the parameters change from one call to another. Indeed, process is called in quickhull_2d with P defined as a NumPy array, and then called from process itself with P defined as a list. Lists and arrays are completely different things. In Numba it is better to use arrays whenever possible, and to use a list only when you need to append an unknown number of items (neither small nor bounded).
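One way to avoid the list entirely is to build K with a boolean mask, so every call receives and produces the same type. candidates_above is a hypothetical helper written in plain NumPy; it only uses 1D boolean-mask indexing and astype, which Numba supports, so it should also compile under @njit, but treat that as an assumption. The two prints show why an explicit int64 dtype matters when K is built from a list: an empty list becomes a float64 array, which cannot be used as an index.

```python
import numpy as np


def candidates_above(signed_dist, P, a, b):
    # Boolean-mask version of the list comprehension: the result is always
    # a 1D int64 array, never a list, so every recursive call sees one type.
    mask = (signed_dist > 0) & (P != a) & (P != b)
    return P[mask].astype(np.int64)


# Pitfall the explicit dtype guards against when converting a list:
print(np.array([]).dtype)                  # float64
print(np.array([], dtype=np.int64).dtype)  # int64
```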

Additionally, max(zip(signed_dist, P))[1] is apparently unsupported by Numba, and it is not very efficient anyway (nor idiomatic NumPy code). P[np.argmax(signed_dist)] should be used instead.
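The two expressions agree except when several points share the maximum distance: max compares (distance, index) tuples, so ties go to the larger index, whereas np.argmax returns the first maximum. Both pick a point at maximal distance from the line, so for the hull either should be usable as the split point; the small arrays below are made up purely to show the tie-breaking.

```python
import numpy as np

signed_dist = np.array([1.0, 3.0, 3.0])
P = np.array([5, 2, 7])

# max() compares (distance, index) tuples, so ties go to the larger index
c_zip = max(zip(signed_dist, P))[1]
# argmax returns the position of the first maximum
c_argmax = P[np.argmax(signed_dist)]

print(c_zip, c_argmax)  # 7 2
```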

Furthermore, Numba's np.cross does not support the case where both inputs are 2-component vectors, so you currently need to use cross2d instead (from numba.np.extensions).
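If you prefer not to depend on cross2d (recent NumPy versions also deprecate 2-element vectors in np.cross), the 2D cross product is short enough to write out by hand. signed_dist is a hypothetical helper; it computes the same quantity as np.cross(S[P] - S[a], S[b] - S[a]) in the question, using only basic slicing and arithmetic, so it should compile under @njit as well.

```python
import numpy as np


def signed_dist(S, P, a, b):
    # z-component of (S[P] - S[a]) x (S[b] - S[a]) for 2D points,
    # i.e. u_x * v_y - u_y * v_x, computed for every candidate at once.
    rel = S[P] - S[a]   # shape (len(P), 2)
    ab = S[b] - S[a]    # shape (2,)
    return rel[:, 0] * ab[1] - rel[:, 1] * ab[0]
```

With the question's points, P = np.arange(5), a = 0 and b = 1, this gives [0, 0, 0, -1, 1], so only point 4 lies on the positive side of the line from point 0 to point 1.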

Finally, when you use a recursive function like this, it is better to specify the types of the parameters so as to avoid confusing type-inference errors. This can be done with a signature string.
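As a small illustration of a signature string on a recursive function, here is a divide-and-conquer sum with the same recursive shape as process. range_sum is a hypothetical example, and the njit fallback shim is an assumption so the snippet runs without Numba. Under Numba, the signature also fixes the accepted dtypes: an int32 array would be rejected rather than triggering a new compilation.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # shim (assumption): run as plain Python without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f


# Return and argument types are fixed by the signature string, so Numba
# does not have to infer them through the recursive calls.
@njit('int64(int64[:], int64, int64)')
def range_sum(x, lo, hi):
    # Divide-and-conquer sum of x[lo:hi]
    if hi - lo <= 0:
        return 0
    if hi - lo == 1:
        return x[lo]
    mid = (lo + hi) // 2
    return range_sum(x, lo, mid) + range_sum(x, mid, hi)
```

For example, range_sum(np.arange(10, dtype=np.int64), 0, 10) returns 45.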

The resulting code is:

import numpy as np
from numba import njit
from numba.np.extensions import cross2d

# Explicit signature: fixes the argument types of the recursive function up front
@njit('(float64[:,:], int64[:], int64, int64)')
def process(S, P, a, b):
    # Proportional to the signed distance of each candidate from the line a -> b
    signed_dist = cross2d(S[P] - S[a], S[b] - S[a])
    # Always an int64 array (never a list), even when it is empty
    K = np.array([i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b], dtype=np.int64)

    if len(K) == 0:
        return [a, b]

    # Farthest point from the line (first one in case of ties)
    c = P[np.argmax(signed_dist)]
    return process(S, K, a, c)[:-1] + process(S, K, c, b)

@njit('(float64[:,:],)')
def quickhull_2d(S):
    # The leftmost and rightmost points are always hull vertices
    a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
    idx = np.arange(S.shape[0])
    return process(S, idx, a, b)[:-1] + process(S, idx, b, a)[:-1]

points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)

print(ch)  # prints [0, 4, 1, 3]

Note that compilation is slow and execution will probably not be fast either. This is mainly due to the lists (and, at runtime, the many temporary arrays they imply). The next step is simply to use arrays. The bad news is that Numba's support for np.concatenate is limited (it only accepts a tuple of arrays; the general case is not easy to implement, though specific cases are trivial). You can instead create a new array and copy each part into it (or, even better, preallocate an array and slice it during the recursive calls).
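Following the suggestion above to switch to arrays, here is a sketch of an array-only version: each call returns a fresh int64 array, and the two chains are joined by allocating the result and copying into slices (the "new array and copy each part" option, not the preallocated variant). hull_chain and quickhull_2d_arr are hypothetical names, the hand-written 2D cross product replaces cross2d, and the njit fallback shim is an assumption so the snippet also runs without Numba. It keeps the question's vertex order and, like the original, does not handle degenerate inputs (e.g. all points sharing the same x) specially.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # shim (assumption): run as plain Python without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f


@njit('(float64[:,:], int64[:], int64, int64)')
def hull_chain(S, P, a, b):
    # Hull vertices from a to b (both included) among the candidates P
    n = P.shape[0]
    dist = np.empty(n)
    count = 0
    c = -1
    best = 0.0
    for k in range(n):
        i = P[k]
        # 2D cross product (S[i] - S[a]) x (S[b] - S[a])
        d = (S[i, 0] - S[a, 0]) * (S[b, 1] - S[a, 1]) - (S[i, 1] - S[a, 1]) * (S[b, 0] - S[a, 0])
        dist[k] = d
        if d > 0 and i != a and i != b:
            count += 1
            if d > best:  # first farthest point wins, like np.argmax
                best = d
                c = i
    if count == 0:
        out = np.empty(2, dtype=np.int64)
        out[0] = a
        out[1] = b
        return out
    # Second pass: copy the kept candidates into an array of the exact size
    K = np.empty(count, dtype=np.int64)
    j = 0
    for k in range(n):
        if dist[k] > 0 and P[k] != a and P[k] != b:
            K[j] = P[k]
            j += 1
    left = hull_chain(S, K, a, c)
    right = hull_chain(S, K, c, b)
    # Array version of left[:-1] + right (c is shared by both chains)
    out = np.empty(left.shape[0] - 1 + right.shape[0], dtype=np.int64)
    out[:left.shape[0] - 1] = left[:-1]
    out[left.shape[0] - 1:] = right
    return out


@njit('(float64[:,:],)')
def quickhull_2d_arr(S):
    a = np.argmin(S[:, 0])
    b = np.argmax(S[:, 0])
    idx = np.arange(S.shape[0])
    lower = hull_chain(S, idx, a, b)
    upper = hull_chain(S, idx, b, a)
    # Drop the last vertex of each chain, as the original code does
    n1 = lower.shape[0] - 1
    n2 = upper.shape[0] - 1
    out = np.empty(n1 + n2, dtype=np.int64)
    out[:n1] = lower[:n1]
    out[n1:] = upper[:n2]
    return out
```

On the question's example points this should return the same indices as before, [0, 4, 1, 3], but as an int64 array instead of a list.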

Also note that any recursive function can be transformed into a non-recursive one using a manual stack. That said, the result may be slower and the code more verbose. There are benefits to this approach, though: it avoids stack overflows when the recursion is deep, and it can be faster if the function is rewritten so that one of the two calls is not pushed onto the stack, as tail-call optimization would do.
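A sketch of the manual-stack variant, under the same assumptions as before (hypothetical name quickhull_2d_iter, hand-written 2D cross product, njit fallback shim so it runs without Numba). Pending chains are kept as (start, end, candidates) entries in three parallel lists; when a chain has no candidate left, it is a hull edge and its start vertex is emitted. It does not attempt the tail-call trick mentioned in the text: both sub-chains are pushed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # shim (assumption): run as plain Python without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f


@njit
def quickhull_2d_iter(S):
    n = S.shape[0]
    a = np.argmin(S[:, 0])
    b = np.argmax(S[:, 0])
    all_idx = np.arange(n)
    # Manual stack of pending chains; the top is the last element,
    # so the chain a -> b is processed before b -> a.
    st_start = [b, a]
    st_end = [a, b]
    st_cand = [all_idx, all_idx]
    out = np.empty(n + 2, dtype=np.int64)  # generous bound on the hull size
    m = 0
    while len(st_start) > 0:
        s = st_start.pop()
        e = st_end.pop()
        P = st_cand.pop()
        count = 0
        c = -1
        best = 0.0
        dist = np.empty(P.shape[0])
        for k in range(P.shape[0]):
            i = P[k]
            d = (S[i, 0] - S[s, 0]) * (S[e, 1] - S[s, 1]) - (S[i, 1] - S[s, 1]) * (S[e, 0] - S[s, 0])
            dist[k] = d
            if d > 0 and i != s and i != e:
                count += 1
                if d > best:
                    best = d
                    c = i
        if count == 0:
            # s -> e is a hull edge: emit its start vertex
            out[m] = s
            m += 1
            continue
        K = np.empty(count, dtype=np.int64)
        j = 0
        for k in range(P.shape[0]):
            if dist[k] > 0 and P[k] != s and P[k] != e:
                K[j] = P[k]
                j += 1
        # Push c -> e first so that s -> c is popped (and emitted) first
        st_start.append(c)
        st_end.append(e)
        st_cand.append(K)
        st_start.append(s)
        st_end.append(c)
        st_cand.append(K)
    return out[:m]
```

The emission order matches the recursive versions, so the question's example should again give [0, 4, 1, 3]; the recursion depth no longer depends on the input, only the size of the Python-level lists does.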
