簡體   English   中英

用python的numba正確使用numpy.vstack

[英]Correct usage of numpy.vstack with python's numba

我目前正在嘗試使用numba加速一些 python 代碼。 根據numba 的文檔,不支持numpy.vstack numpy.meshgrid 所以我用vstack替換了meshgrid調用,它在不使用numba時工作正常。 但是,在使用numba時它不起作用。

這是代碼:

import numpy as np
import timeit
from numba import njit, prange

@njit(parallel=True)
def method1_par1(n):
    # create some data
    x   = np.linspace(0,2*np.pi,n)
    y   = np.linspace(0,2*np.pi,n)
    # np.meshgrid not supported by numba, therefore trying to use np.vstack
    X   = np.vstack( n*[x] )
    Y   = np.hstack( n*[y[:,np.newaxis]] )    # tranform row vector into column vector, then duplicate
    Z   = np.sin(X) * np.sin(Y)

    # calculate centered finite difference using for loop
    Z_diff1 = Z*.0
    for ii in prange(1,Z.shape[0]-1,2):
        for jj in range(1,Z.shape[1]-1,2):
            Z_diff1[ii,jj]  = Z[ii+1, jj] - Z[ii-1,jj]

runs    = 1000
print( min(timeit.repeat( "method1_par1(50)", "from __main__ import method1_par1",
           number=runs )) )

這是錯誤消息:

No implementation of function Function(<built-in function getitem>) found for signature:

>>> getitem(array(float64, 1d, C), Tuple(slice<a:b>, none))

There are 22 candidate implementations:
  - Of which 20 did not match due to:
  Overload of function 'getitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64, 1d, C), Tuple(slice<a:b>, none))':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 162.
    With argument(s): '(array(float64, 1d, C), Tuple(slice<a:b>, none))':
  Rejected as the implementation raised a specific error:
    TypeError: unsupported array index type none in Tuple(slice<a:b>, none)
 raised from /usr/local/lib/python3.8/dist-packages/numba/core/typing/arraydecl.py:68

During: typing of intrinsic-call at forloop_slicing_parallel_q.py (12)
During: typing of static-get-item at forloop_slicing_parallel_q.py (12)

File "forloop_slicing_parallel_q.py", line 12:
def method1_par1(n):
    <source elided>
    # np.meshgrid not supported by numba, therefore trying to use np.vstack
    Y   = np.hstack( n*[y[:,np.newaxis]] )
    ^

在我看來,不支持我使用vstack的方式,對嗎?

我正在使用的版本:

numpy:  1.20.3
numba:  0.54.1
python: 3.8.10

Update1:使用np.concatenate( n*[[x]] )會導致同樣的問題

如果您仔細查看錯誤消息,您會看到它說

No implementation of function Function(<built-in function getitem>) found for signature:

>>> getitem(array(float64, 1d, C), Tuple(slice<a:b>, none))

getitem是 numba 編譯[]運算符的方式。 簽名表明 numba 不支持像array[slice, None]這樣的調用。 具體來說,問題是

y[:, np.newaxis]

Numba 確實支持重塑操作,因此您可以將該行更改為

Y = np.hstack(n * [y.reshape(-1, 1)])

在這種情況下,您還可以考慮使用repeat而不是列表乘法和堆疊:

Y = y.reshape(-1, 1).repeat(n, 1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM