[英]How can I make this function taking an array of arrays as input compile with numba?
函數的簽名是
def SLBQP(Q, q, u, a, x, eps=1e-6, maxIter=1000):
它返回一個 float64。
參數的類型是:
Q -- np.array([[1., 2.], [4., 5.]])
q -- np.array([1.,2.,3.,4.])
u -- a scalar
a -- np.array([1.,2.,3.,4.])
x -- np.array([1.,2.,3.,4.])
我試過
@jit('f8(f8[:,:], f8[:], f8, f8[:], f8[:], f8, i4)',nopython=True)
def SLBQP(Q, q, u, a, x, eps=1e-6, maxIter=1000):
它給了我這個錯誤:
Invalid use of Function(<built-in function array>) with argument(s) of type(s): (array(float64, 1d, C))
* parameterized
In definition 0:
TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
raised from /Users/gerardozinno/Desktop/ProgettoML/venv/lib/python3.8/site-packages/numba/typing/npydecl.py:472
In definition 1:
TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
raised from /Users/gerardozinno/Desktop/ProgettoML/venv/lib/python3.8/site-packages/numba/typing/npydecl.py:472
我也試過這個:
@jit('numba.float64(numba.array(float64, 2d, C), numba.array(float64, 1d, C), numba.float64, numba.array(float64, 1d, C), numba.array(float64, 1d, C), numba.float64, numba.int64)',nopython=True)
它給了我一個語法錯誤。
編輯:
我試過這個簽名:
@nb.njit('f8(f8[:,:], f8[:], f8, f8[:], f8[:], f8, i4)')
由 Thane Brooker 在答案部分建議,它給了我這個錯誤:
Invalid use of Function(<built-in function array>) with argument(s) of type(s): (array(float64, 1d, C))
* parameterized
In definition 0:
TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
raised from /Users/gerardozinno/Desktop/ProgettoML/venv/lib/python3.8/site-packages/numba/typing/npydecl.py:472
In definition 1:
TypingError: array(float64, 1d, C) not allowed in a homogeneous sequence
raised from /Users/gerardozinno/Desktop/ProgettoML/venv/lib/python3.8/site-packages/numba/typing/npydecl.py:472
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function array>)
[2] During: typing of call at /Users/gerardozinno/Desktop/NUOVO/ProgettoML/svr/SLBQP.py (119)
File "SLBQP.py", line 119:
def SLBQP(Q, q, u, a, x, eps=1e-6, maxIter=1000):
<source elided>
v = np.dot(Qx,x) + np.dot(q, x)
g = np.array(Qx+q)
^
這工作沒有錯誤。
import numba as nb
import numpy as np
@nb.njit('f8(f8[:,:], f8[:], f8, f8[:], f8[:], f8, i4)')
def SLBQP(Q, q, u, a, x, eps=1e-6, maxIter=1000):
return 1.
Q = np.array([[1., 2.], [4., 5.]])
q = np.array([1.,2.,3.,4.])
u = 50
a = np.array([1.,2.,3.,4.])
x = np.array([1.,2.,3.,4.])
result = SLBQP(Q, q, u, a, x, eps=1e-6, maxIter=1000)
我更改了您的示例 Q 變量(我認為這是一個錯字),但否則我無法復制您的語法錯誤。 我猜你傳遞給函數的 Q 是一維數組,而不是你認為的二維數組。 查看Q.shape
和Q.flags
進行檢查。
我通過在我的函數之前編寫這段代碼解決了這個問題。
from numba.extending import overload
@overload(np.array)
def np_array_ol(x):
if isinstance(x, types.Array):
def impl(x):
return np.copy(x)
return impl
@nb.njit('f8(f8[:,:], f8[:], f8, f8[:], f8[:], f8, i4)')
def SLBQP(Q, q, u, a, x, eps=1e-6, maxIter=1000):
...
顯然,編輯部分中寫入的錯誤是由 numba 內部的錯誤引起的。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.