繁体   English   中英

使用 Numba 快速计算凸包

[英]Computing quick convex hull using Numba

我遇到了使用 Numpy 实现计算 2d 点的凸包的这个很好的实现。 我希望能够@njit 这个函数以在我的其他 Numba jitted 代码中使用它。 但是我无法修改它,无法运行,因为它使用递归和不受支持的 Numba 功能? 谁能帮我重写这个?

import numpy as np
from numba import njit

def process(S, P, a, b):
    signed_dist = np.cross(S[P] - S[a], S[b] - S[a])
    K = [i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b]

    if len(K) == 0:
        return (a, b)

    c = max(zip(signed_dist, P))[1]
    return process(S, K, a, c)[:-1] + process(S, K, c, b)

def quickhull_2d(S: np.ndarray) -> np.ndarray:
    a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
    max_index = np.argmax(S[:,0])
    max_element = S[max_index]
    return process(S, np.arange(S.shape[0]), a, max_index)[:-1] + process(S, np.arange(S.shape[0]), max_index, a)[:-1]

示例数据输入和输出

points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)

print(ch)
[0, 4, 1, 3]

print(points[ch])
[[0. 0.]
 [1. 0.]
 [1. 1.]
 [0. 1.]]

要使用 Numba,此代码中存在许多问题。

首先,在 Numba 中返回可变大小的元组是不可能的,因为元组的类型隐含地包含它的大小。 元组基本上是一种结构化类型,而不是列表。 有关此问题的更多信息,请参阅这篇文章这篇文章。 解决方案基本上是返回一个列表(慢)或一个数组(快)。

此外,参数的类型从一个函数变为另一个函数 实际上,在quickhull_2d中调用process时将P定义为 Numpy 数组,然后从process本身调用并将P定义为列表。 列表和数组是完全不同的东西。 最好在 Numba 中尽可能使用数组,除非您使用列表来添加未知数量的项目(既不小也不有限)。

此外, max(zip(signed_dist, P))[1]显然不受 Numba 支持,而且它无论如何都不是很有效(对于 Numpy 代码也不是惯用的)。 应该P[np.argmax(signed_dist)]

此外, np.cross似乎也不支持一般情况,您目前需要使用cross2d (来自numba.np.extensions )。

最后,当你像这样使用递归函数时,最好指定参数的输入类型,以避免出现奇怪的错误。 这可以通过签名字符串来完成。

结果代码是:

import numpy as np
from numba import njit
from numba.np.extensions import cross2d

@njit('(float64[:,:], int64[:], int64, int64)')
def process(S, P, a, b):
    signed_dist = cross2d(S[P] - S[a], S[b] - S[a])
    K = np.array([i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b], dtype=np.int64)

    if len(K) == 0:
        return [a, b]

    c = P[np.argmax(signed_dist)]
    return process(S, K, a, c)[:-1] + process(S, K, c, b)

@njit('(float64[:,:],)')
def quickhull_2d(S: np.ndarray) -> np.ndarray:
    a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
    max_index = np.argmax(S[:,0])
    max_element = S[max_index]
    return process(S, np.arange(S.shape[0]), a, max_index)[:-1] + process(S, np.arange(S.shape[0]), max_index, a)[:-1]

points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)

print(ch) # print [0, 4, 1, 3]

请注意,编译时间很慢,执行时间应该不会很大。 这是由于列表(以及用于运行时性能的临时数组)。 下一步就是简单地使用数组。 坏消息是concatenate不支持连接(因为一般情况下不容易实现,尽管具体情况很简单)。 您可以创建一个新数组并复制每个部分(或者更好:您可以预分配一个数组并在递归调用期间对其进行切片)。

也不是说任何递归函数都可以使用手动堆栈转换为非递归函数。 话虽如此,它可能会变慢并使代码更加冗长。 不过,这种方法有一些好处:当递归很深时,它可以避免堆栈溢出,并且如果函数被重写,它可能会更快,因为尾调用优化不会将函数调用之一堆叠起来。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM