簡體   English   中英

如何使 numba(nopython=true) 與元素數量未知的 1D numpy.ndarray 輸入一起使用

[英]How to make numba(nopython=true) work with 1D numpy.ndarray input with unknown number of elements

我正在將一個(數學復雜/涉及但操作很少)自制經驗分布 class 從 C++/MATLAB(我都有)移植到 Python。

該文件有大約 1100 行代碼,包括注釋和測試數據,包括

if __name__ == "__main__": 

在文件的底部。

第 83 行有 function 聲明: def cdf(self, x):

哪個編譯並運行良好,它只是非常慢,所以我想用@numba.jit(nopython=True)編譯以使其運行得更快。

但是,編譯在文件npts=len(x)的 function (僅前面的注釋)第 85 行的最早行之一上終止。

消息以:

[1] During: typing of argument at
C:\Users\kdalbey\Canopy\scripts\empDist.py (85)
--%<-----------------------------------------------------------------

File "Canopy\scripts\empDist.py", line 85

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.empDist'>

現在我真的在文件頂部做了一個import numpy as np但是為了清楚下面的這條消息,我嘗試用numpy替換np 但我可能錯過了一些。

如果我使用npts=x.size ,我會收到相同的錯誤消息。

所以我嘗試輸入x為:

@numba.jit(nopython=True)
def cdf(self, x: numpy.ndarray(dtype=numpy.float64)):

我得到以下錯誤

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
C:\Users\kdalbey\Canopy\scripts\empDist.py in <module>()
     15 np.set_printoptions(precision=16)
     16 
---> 17 class empDist:
     18     def __init__(self, xdata):
     19         npts=len(xdata)
C:\Users\kdalbey\Canopy\scripts\empDist.py in empDist()
     81 
     82     @numba.jit(nopython=True)
---> 83     def cdf(self, x: np.ndarray(dtype=np.float64)):
     84         # compute the value of cdf at vector of points x
     85         npts = x.size
TypeError: Required argument 'shape' (pos 1) not found

但是我不知道1D numpy.ndarray有多少元素(隨意)

我猜我也許可以做一個

@numba.jit(nopython=True)
def cdf(self, x: numpy.ndarray(shape=(), dtype=numpy.float64)):

並且它僅通過 go 回到該錯誤

[1] During: typing of argument at
C:\Users\kdalbey\Canopy\scripts\empDist.py (85)
--%<-----------------------------------------------------------------
File "Canopy\scripts\empDist.py", line 85
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.empDist'>

如果我執行npts=int(x.size)npts=numpy.int32(x.size)也是同樣的錯誤,所以我認為問題出在x上。

由於多個問題(從 numba 版本 0.46.0 開始),您的方法存在問題:

  • numpy.ndarray(shape=(), dtype=numpy.float64)真的試圖創建一個 NumPy 數組。 將它用作類型提示並不重要。 它仍然被執行(並且失敗)。
  • 您應該在jit中使用更合適的(對於 numba)簽名,而不是類型提示。 甚至更好:完全省略簽名,讓 numba 弄清楚。 在大多數情況下,numba 更勝一籌,並且花費更少的精力(如果您不需要限制類型)。
  • 你不能在jit模式下 jit 方法。 更好的方法是制作 function 並從您的方法中調用它。

所以在你的情況下:

import numba as nb

@nb.njit
def _cdf(x):
    # do something with x

class empDist:
    def cdf(self, x):
        result = _cds(x)
        ...

您的示例可能更復雜,但這應該為您提供一個很好的起點。 如果您需要使用實例屬性,則只需將它們傳遞給_cdf (如果 numba 支持它們)。


一般來說,嘗試在所有東西上使用 numba 並不是一個好主意。 Numba 的 scope 非常有限,但在它適用的地方,它可以是驚人的。

在您的情況下,您說它很慢。 那么第一步應該是分析你的代碼並找出它為什么慢以及在哪里。 然后嘗試找出是否可以用更快的方法解決這個瓶頸。 通常問題不在於代碼本身,而在於算法/方法。 檢查它是否使用次優方法。 如果它不是一個數字繁重的部分,那么使用 numba 可能是有意義的 - 但請注意:通常您根本不需要 numba,因為只需優化 NumPy 部件即可獲得足夠的性能。

好的...問題是它是一個方法(成員函數),我是從 MrFuppes 那里得到的。 將它隔離在它自己的 function 中,該方法調用的方法效果很好(幾乎沒有對 function 進行修改,在 numba 之前工作)。

順便說一句,我將嘗試獲得批准以發布/發布經驗分發代碼,但這還有一段路要走。 我也可能想學習 cython 並重新編碼以提高 cython 的速度,在我的機器上編譯需要 O(秒),因為這些操作在數學上很復雜/涉及但從失敗計數的角度來看並沒有很多。 與 sklearn.neighbors.kde 相比,我的經驗分布要快得多(在 @numba.jit(nopython=True) 編譯緩存之后/折扣之后)。 在 windows 上的樹冠中運行(使用 numba 0.36.2,因此 np.interp 沒有從 numba 中受益)構建這個經驗分布需要 5.72e-5 秒,而擬合sklearn kde 需要 2.03e-4 秒,獲得 463 點。 此外,它應該很好地擴展到非常多的點。 除了 O(n log(n)) 的快速排序和 O(n) 的插值之外,構造(以及存儲對象所需的 memory)成本為 O(n^(1/3))(具有顯着系數到 O(n^(1/3))。它具有 PDF、CDF 和逆 CDF 的“簡單”分析公式,因此經驗分布的評估速度也快了很多。它與 sklearn 具有可比/略好的准確性高斯的 KDE(使用帶寬 = (maxx-minx)*0.015 我復制了帶寬,所以其他人的代碼可能比我的 sklearn kde 更好,顯然 kde 的准確性很大程度上取決於帶寬,我的經驗分布在構建過程中不采用除數據以外的任何參數,它通過算法計算出它需要了解的關於數據的所有信息),並且對於具有有限尾部(例如均勻或指數)的東西具有顯着更好的准確性。提高的准確性部分來自於它悶悶不樂與 sklearn kde 相比,振動較小。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM