[英]Computing quick convex hull using Numba
我遇到了使用 Numpy 实现计算 2d 点的凸包的这个很好的实现。 我希望能够@njit 这个函数以在我的其他 Numba jitted 代码中使用它。 但是我无法修改它,无法运行,因为它使用递归和不受支持的 Numba 功能? 谁能帮我重写这个?
import numpy as np
from numba import njit
def process(S, P, a, b):
signed_dist = np.cross(S[P] - S[a], S[b] - S[a])
K = [i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b]
if len(K) == 0:
return (a, b)
c = max(zip(signed_dist, P))[1]
return process(S, K, a, c)[:-1] + process(S, K, c, b)
def quickhull_2d(S: np.ndarray) -> np.ndarray:
a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
max_index = np.argmax(S[:,0])
max_element = S[max_index]
return process(S, np.arange(S.shape[0]), a, max_index)[:-1] + process(S, np.arange(S.shape[0]), max_index, a)[:-1]
示例数据输入和输出
points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)
print(ch)
[0, 4, 1, 3]
print(points[ch])
[[0. 0.]
[1. 0.]
[1. 1.]
[0. 1.]]
要使用 Numba,此代码中存在许多问题。
首先,在 Numba 中返回可变大小的元组是不可能的,因为元组的类型隐含地包含它的大小。 元组基本上是一种结构化类型,而不是列表。 有关此问题的更多信息,请参阅这篇文章和这篇文章。 解决方案基本上是返回一个列表(慢)或一个数组(快)。
此外,参数的类型从一个函数变为另一个函数。 实际上,在quickhull_2d
中调用process
时将P
定义为 Numpy 数组,然后从process
本身调用并将P
定义为列表。 列表和数组是完全不同的东西。 最好在 Numba 中尽可能使用数组,除非您使用列表来添加未知数量的项目(既不小也不有限)。
此外, max(zip(signed_dist, P))[1]
显然不受 Numba 支持,而且它无论如何都不是很有效(对于 Numpy 代码也不是惯用的)。 应该P[np.argmax(signed_dist)]
。
此外, np.cross
似乎也不支持一般情况,您目前需要使用cross2d
(来自numba.np.extensions
)。
最后,当你像这样使用递归函数时,最好指定参数的输入类型,以避免出现奇怪的错误。 这可以通过签名字符串来完成。
结果代码是:
import numpy as np
from numba import njit
from numba.np.extensions import cross2d
@njit('(float64[:,:], int64[:], int64, int64)')
def process(S, P, a, b):
signed_dist = cross2d(S[P] - S[a], S[b] - S[a])
K = np.array([i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b], dtype=np.int64)
if len(K) == 0:
return [a, b]
c = P[np.argmax(signed_dist)]
return process(S, K, a, c)[:-1] + process(S, K, c, b)
@njit('(float64[:,:],)')
def quickhull_2d(S: np.ndarray) -> np.ndarray:
a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
max_index = np.argmax(S[:,0])
max_element = S[max_index]
return process(S, np.arange(S.shape[0]), a, max_index)[:-1] + process(S, np.arange(S.shape[0]), max_index, a)[:-1]
points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)
print(ch) # print [0, 4, 1, 3]
请注意,编译时间很慢,执行时间应该不会很大。 这是由于列表(以及用于运行时性能的临时数组)。 下一步就是简单地使用数组。 坏消息是concatenate
不支持连接(因为一般情况下不容易实现,尽管具体情况很简单)。 您可以创建一个新数组并复制每个部分(或者更好:您可以预分配一个数组并在递归调用期间对其进行切片)。
也不是说任何递归函数都可以使用手动堆栈转换为非递归函数。 话虽如此,它可能会变慢并使代码更加冗长。 不过,这种方法有一些好处:当递归很深时,它可以避免堆栈溢出,并且如果函数被重写,它可能会更快,因为尾调用优化不会将函数调用之一堆叠起来。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.