簡體   English   中英

在 python numba jitclass 中調用 njit function 失敗

[英]calling njit function in python numba jitclass fails

@njit
def cumutrapz(x:np.array, y:np.array):
    return np.append(0, [
        np.trapz(y=y[i-2:i], x=x[i-2:i]) for i in range(2, len(x) + 1)]).cumsum()

from numba import float64
@jitclass([
    ('a', float64[:]),
    ('b', float64[:]),    
    ('c', float64[:]),    
])
class Testaroo(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.c = np.zeros(len(self.a), dtype=np.float64)
        
    def set_c(self):
        self.c = cumutrapz(self.a, self.b)
        
testaroo = Testaroo(  
    np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()

以上失敗,但以下兩個非常相似的示例有效:

cumutrapz(np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))

from numba import float64
@jitclass([
    ('a', float64[:]),
    ('b', float64[:]),    
    ('c', float64[:]),    
])
class Testaroo(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.c = np.zeros(len(self.a), dtype=np.float64)
        
    def set_c(self):
        self.c = (self.a * self.b).cumsum()
        
testaroo = Testaroo(  
   np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()

后一個例子現在對我有用,但我想知道是否有辦法讓cumutrapz function 在jitclass內工作。

我正在使用 numba 版本“0.53.1”。

仔細閱讀您可以找到的長錯誤消息:

No implementation of function Function(<function trapz at 0x7f7e9b21e5e0>) 
found for signature:
>>> trapz(y=array(float64, 1d, A), x=array(float64, 1d, A))
...
reshape() supports contiguous array only

格式 A(任意)的 Arrays 不一定是連續的。

您可以確保function 僅處理連續的 arrays:

@njit([nb.float64[::1](nb.float64[::1], nb.float64[::1])])
def cumutrapz(x, y):
    ...

然后出現一個新的錯誤:

Invalid use of type(CPUDispatcher(<function MyTestCase.test_cumutrapz.<locals>.cumutrapz at 0x7f8e15a841f0>))
with parameters (array(float64, 1d, A), array(float64, 1d, A))
Known signatures:
    * (array(float64, 1d, C), array(float64, 1d, C)) -> array(float64, 1d, C)
...
    self.c = cumutrapz(self.a, self.b)
    ^

所以 class 中的 arrays 是不連續的。

為了確保它們是,您可以將 class 規范更改為:

@jitclass([
    ('a', nb.float64[::1]),
    ('b', nb.float64[::1]),
    ('c', nb.float64[::1]),
    ])

現在它可以工作了(用 Numba 0.54.0 測試)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM