
Numba - Accessing Numpy Array with variable index


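I'm looking into using numba to speed up iterative calculations, particularly in cases where a calculation sometimes depends on the result of a previous one, so vectorization isn't always applicable. One thing I've found missing is that it doesn't seem to support DataFrames. I figured that's fine: you can pass a 2D numpy array plus a numpy array of column names, so I tried to implement a function that looks up values by column name instead of by index. Here is the code I have so far: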

from numba import jit
import numpy as np


@jit(nopython=True)
def get_index(cols, col):
    # Linear search for the position of column name `col` in `cols`
    for i in range(len(cols)):
        if cols[i] == col:
            return i


@jit(nopython=True)
def get_element(ndarr: np.ndarray, cols: np.ndarray, row: np.int8, name: str):
    ind = get_index(cols, name)
    print(row)
    print(ind)
    print(ndarr[0][0])
    # print(ndarr[row][ind])  # uncommenting this line raises the TypingError below


get_element(np.array([['HI'], ['BYE'], ['HISAHASDG']]), np.array(['COLUMN_1']), 0, "COLUMN_1")

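I've tested get_index on its own and it works correctly. It's basically an implementation of np.where, and I'm wondering whether it is what causes my error. With the last print commented out, this code runs: it prints 0, 0, and then "HI", as expected. So in theory, all the commented-out line should do is print "HI", just like the print on the line before it, since row and ind are both 0. But when I uncomment it, I get the following: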

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<timed exec> in <module>

/sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

/sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

/sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/utils.py in reraise(tp, value, tb)
     78         value = tp()
     79     if value.__traceback__ is not tb:
---> 80         raise value.with_traceback(tb)
     81     raise value
     82 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array([unichr x 50], 1d, C), OptionalType(int64) i.e. the type 'int64 or None')
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
In definition 10:
    All templates rejected with literals.
In definition 11:
    All templates rejected without literals.
In definition 12:
    TypeError: unsupported array index type OptionalType(int64) i.e. the type 'int64 or None' in [OptionalType(int64)]
    raised from /sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/typing/arraydecl.py:69
In definition 13:
    TypeError: unsupported array index type OptionalType(int64) i.e. the type 'int64 or None' in [OptionalType(int64)]
    raised from /sas/python/app/miniconda3/envs/py3lu/lib/python3.6/site-packages/numba/core/typing/arraydecl.py:69
In definition 14:
    All templates rejected with literals.
In definition 15:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at <timed exec> (15)

File "<timed exec>", line 15:
<source missing, REPL/exec in use?>

Is there something I'm missing? I checked the types of row and ind, and they are indeed ints. Why won't numba let me index with an int variable? Thanks.
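For reference, here is roughly what the np.where lookup that get_index is compared to looks like in plain NumPy (outside numba). The column names are illustrative. Note that a missing name yields an empty array rather than an error, so the "not found" case has to be handled explicitly there as well.

```python
import numpy as np

cols = np.array(['COLUMN_1', 'COLUMN_2'])

# With a single argument, np.where returns a tuple of index arrays (one per dimension)
hits = np.where(cols == 'COLUMN_2')[0]
print(hits)  # [1]

# A name that is not present gives an empty array, not an exception
missing = np.where(cols == 'COLUMN_3')[0]
print(missing.size)  # 0
```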

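numba is actually being really clever here! Consider what happens when you pass get_index a col that isn't in cols. cols[i] == col is never true, so the loop finishes, and since there is no catch-all return at the end of the function, the return value is None.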
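The fall-through behaviour is ordinary Python semantics, not something numba adds. A minimal plain-Python sketch (no numba needed) of the original function shows it:

```python
def get_index(cols, col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    # No return statement here: falling off the end implicitly returns None


print(get_index(['COLUMN_1'], 'COLUMN_1'))  # 0
print(get_index(['COLUMN_1'], 'COLUMN_2'))  # None
```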

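So numba correctly infers that the return type of get_index is OptionalType(int64), i.e. a value that may be either an int64 or None. But None is not a valid index type, so you can't index an array with a value that might be None.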

You can fix this by adding a catch-all return at the end:

@jit(nopython=True)
def get_index(cols, col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    return -1  # every path now returns an int, so the return type is plain int64
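One caveat with the sentinel, shown here in plain NumPy with made-up data: -1 is itself a valid NumPy index (it selects the last element), so an unknown column name silently returns a value from the last column instead of failing.

```python
import numpy as np


def get_index(cols, col):
    # Same sentinel version as above, written as plain Python for illustration
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    return -1


data = np.array([['a1', 'b1'], ['a2', 'b2']])
cols = np.array(['COLUMN_1', 'COLUMN_2'])

ind = get_index(cols, 'NOT_A_COLUMN')
print(ind)           # -1
print(data[0, ind])  # b1  (last column, no error raised)
```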

Of course, that may not be the behavior you want here; raising an exception is probably better, and numba handles that correctly as well.

@jit(nopython=True)
def get_index(cols, col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    # Raising instead of returning keeps the return type int64 and fails loudly
    raise IndexError('list index out of range')
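Putting it together, a minimal sketch of the question's get_element using the raising get_index. It assumes numba is available, but falls back to a no-op decorator (a hypothetical stand-in, running as slow plain Python) so the sketch still executes without it. The 2D lookup is written as `ndarr[row, ind]`, which is equivalent to `ndarr[row][ind]` here.

```python
import numpy as np

try:
    from numba import njit  # njit is equivalent to jit(nopython=True)
except ImportError:
    # Hypothetical fallback: run as plain Python when numba is not installed
    def njit(func):
        return func


@njit
def get_index(cols, col):
    for i in range(len(cols)):
        if cols[i] == col:
            return i
    raise IndexError('list index out of range')


@njit
def get_element(ndarr, cols, row, name):
    # Under numba, ind is inferred as int64 rather than Optional(int64)
    ind = get_index(cols, name)
    return ndarr[row, ind]


data = np.array([['HI'], ['BYE'], ['HISAHASDG']])
cols = np.array(['COLUMN_1'])
print(get_element(data, cols, 0, 'COLUMN_1'))  # HI
print(get_element(data, cols, 1, 'COLUMN_1'))  # BYE
```

Asking for a column that does not exist now raises IndexError at call time instead of producing a typing error at compile time.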
