簡體   English   中英

使用 Numba 快速計算凸包

[英]Computing quick convex hull using Numba

我遇到了使用 Numpy 實現計算 2d 點的凸包的這個很好的實現。 我希望能夠@njit 這個函數以在我的其他 Numba jitted 代碼中使用它。 但是我無法修改它,無法運行,因為它使用遞歸和不受支持的 Numba 功能? 誰能幫我重寫這個?

import numpy as np
from numba import njit

def process(S, P, a, b):
    signed_dist = np.cross(S[P] - S[a], S[b] - S[a])
    K = [i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b]

    if len(K) == 0:
        return (a, b)

    c = max(zip(signed_dist, P))[1]
    return process(S, K, a, c)[:-1] + process(S, K, c, b)

def quickhull_2d(S: np.ndarray) -> np.ndarray:
    a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
    max_index = np.argmax(S[:,0])
    max_element = S[max_index]
    return process(S, np.arange(S.shape[0]), a, max_index)[:-1] + process(S, np.arange(S.shape[0]), max_index, a)[:-1]

示例數據輸入和輸出

points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)

print(ch)
[0, 4, 1, 3]

print(points[ch])
[[0. 0.]
 [1. 0.]
 [1. 1.]
 [0. 1.]]

要使用 Numba,此代碼中存在許多問題。

首先,在 Numba 中返回可變大小的元組是不可能的,因為元組的類型隱含地包含它的大小。 元組基本上是一種結構化類型,而不是列表。 有關此問題的更多信息,請參閱這篇文章這篇文章。 解決方案基本上是返回一個列表(慢)或一個數組(快)。

此外,參數的類型從一個函數變為另一個函數 實際上,在quickhull_2d中調用process時將P定義為 Numpy 數組,然后從process本身調用並將P定義為列表。 列表和數組是完全不同的東西。 最好在 Numba 中盡可能使用數組,除非您使用列表來添加未知數量的項目(既不小也不有限)。

此外, max(zip(signed_dist, P))[1]顯然不受 Numba 支持,而且它無論如何都不是很有效(對於 Numpy 代碼也不是慣用的)。 應該P[np.argmax(signed_dist)]

此外, np.cross似乎也不支持一般情況,您目前需要使用cross2d (來自numba.np.extensions )。

最后,當你像這樣使用遞歸函數時,最好指定參數的輸入類型,以避免出現奇怪的錯誤。 這可以通過簽名字符串來完成。

結果代碼是:

import numpy as np
from numba import njit
from numba.np.extensions import cross2d

@njit('(float64[:,:], int64[:], int64, int64)')
def process(S, P, a, b):
    signed_dist = cross2d(S[P] - S[a], S[b] - S[a])
    K = np.array([i for s, i in zip(signed_dist, P) if s > 0 and i != a and i != b], dtype=np.int64)

    if len(K) == 0:
        return [a, b]

    c = P[np.argmax(signed_dist)]
    return process(S, K, a, c)[:-1] + process(S, K, c, b)

@njit('(float64[:,:],)')
def quickhull_2d(S: np.ndarray) -> np.ndarray:
    a, b = np.argmin(S[:,0]), np.argmax(S[:,0])
    max_index = np.argmax(S[:,0])
    max_element = S[max_index]
    return process(S, np.arange(S.shape[0]), a, max_index)[:-1] + process(S, np.arange(S.shape[0]), max_index, a)[:-1]

points = np.array([[0, 0], [1, 1], [0.5, 0.5], [0, 1], [1, 0]])
ch = quickhull_2d(points)

print(ch) # print [0, 4, 1, 3]

請注意,編譯時間很慢,執行時間應該不會很大。 這是由於列表(以及用於運行時性能的臨時數組)。 下一步就是簡單地使用數組。 壞消息是concatenate不支持連接(因為一般情況下不容易實現,盡管具體情況很簡單)。 您可以創建一個新數組並復制每個部分(或者更好:您可以預分配一個數組並在遞歸調用期間對其進行切片)。

也不是說任何遞歸函數都可以使用手動堆棧轉換為非遞歸函數。 話雖如此,它可能會變慢並使代碼更加冗長。 不過,這種方法有一些好處:當遞歸很深時,它可以避免堆棧溢出,並且如果函數被重寫,它可能會更快,因為尾調用優化不會將函數調用之一堆疊起來。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM