[英]Correct usage of numpy.vstack with python's numba
我目前正在嘗試使用numba
加速一些 python 代碼。 根據numba 的文檔,不支持numpy.vstack
numpy.meshgrid
。 所以我用vstack
替換了meshgrid
調用,它在不使用numba
時工作正常。 但是,在使用numba
時它不起作用。
這是代碼:
import numpy as np
import timeit
from numba import njit, prange
@njit(parallel=True)
def method1_par1(n):
# create some data
x = np.linspace(0,2*np.pi,n)
y = np.linspace(0,2*np.pi,n)
# np.meshgrid not supported by numba, therefore trying to use np.vstack
X = np.vstack( n*[x] )
Y = np.hstack( n*[y[:,np.newaxis]] ) # tranform row vector into column vector, then duplicate
Z = np.sin(X) * np.sin(Y)
# calculate centered finite difference using for loop
Z_diff1 = Z*.0
for ii in prange(1,Z.shape[0]-1,2):
for jj in range(1,Z.shape[1]-1,2):
Z_diff1[ii,jj] = Z[ii+1, jj] - Z[ii-1,jj]
runs = 1000
print( min(timeit.repeat( "method1_par1(50)", "from __main__ import method1_par1",
number=runs )) )
這是錯誤消息:
No implementation of function Function(<built-in function getitem>) found for signature:
>>> getitem(array(float64, 1d, C), Tuple(slice<a:b>, none))
There are 22 candidate implementations:
- Of which 20 did not match due to:
Overload of function 'getitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 1d, C), Tuple(slice<a:b>, none))':
No match.
- Of which 2 did not match due to:
Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 162.
With argument(s): '(array(float64, 1d, C), Tuple(slice<a:b>, none))':
Rejected as the implementation raised a specific error:
TypeError: unsupported array index type none in Tuple(slice<a:b>, none)
raised from /usr/local/lib/python3.8/dist-packages/numba/core/typing/arraydecl.py:68
During: typing of intrinsic-call at forloop_slicing_parallel_q.py (12)
During: typing of static-get-item at forloop_slicing_parallel_q.py (12)
File "forloop_slicing_parallel_q.py", line 12:
def method1_par1(n):
<source elided>
# np.meshgrid not supported by numba, therefore trying to use np.vstack
Y = np.hstack( n*[y[:,np.newaxis]] )
^
在我看來,不支持我使用vstack
的方式,對嗎?
我正在使用的版本:
numpy: 1.20.3
numba: 0.54.1
python: 3.8.10
Update1:使用np.concatenate( n*[[x]] )
會導致同樣的問題
如果您仔細查看錯誤消息,您會看到它說
No implementation of function Function(<built-in function getitem>) found for signature:
>>> getitem(array(float64, 1d, C), Tuple(slice<a:b>, none))
getitem
是 numba 編譯[]
運算符的方式。 簽名表明 numba 不支持像array[slice, None]
這樣的調用。 具體來說,問題是
y[:, np.newaxis]
Numba 確實支持重塑操作,因此您可以將該行更改為
Y = np.hstack(n * [y.reshape(-1, 1)])
在這種情況下,您還可以考慮使用repeat
而不是列表乘法和堆疊:
Y = y.reshape(-1, 1).repeat(n, 1)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.