简体   繁体   English

使用 Numba 提取 numpy 数组中的特定行

[英]Extracting specific rows in numpy array using Numba

I have the following array:我有以下数组:

import numpy as np
from numba import njit


test_array = np.random.rand(4, 10)

I create a "jitted" function that slices the array and does some operations afterwards:我创建了一个“jitted” function 切片数组并在之后执行一些操作:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = test_array[[0,1,3]]

   return test_array_sliced

However, Numba throws the following error:但是,Numba 会引发以下错误:

In definition 11:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.

Workaround解决方法

I have tried to delete the rows I do not need by using np.delete , but since I have to specify an axis Numba throws the following error:我尝试使用np.delete删除不需要的行,但由于我必须指定axis Numba 会引发以下错误:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = np.delete(test_array, obj = 2, axis = 0)

   return test_array_sliced

In definition 1:
    TypeError: np_delete() got an unexpected keyword argument 'axis'
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/templates.py:475
This error is usually caused by passing an argument of a type that is unsupported by the named function.

Any ideas of how to extract specific rows under Numba?关于如何在 Numba 下提取特定行的任何想法?

I think it will work (it seems to suggest so in the docs ) if you index with an array instead of a list:如果您使用数组而不是列表进行索引,我认为它会起作用(似乎在文档中建议这样做):

test_array_sliced = array[np.array([0,1,3])]

(I changed the array you're slicing to array , which is what you pass in to the function. Maybe it was intentional, but be careful with globals!) (我将您要切片的数组更改为array ,这是您传递给 function 的内容。也许这是故意的,但要小心全局变量!)

Numba does not support numpy fancy indexing. Numba 不支持 numpy 花式索引。 I'm not 100% sure what your real use case looks like, but a simple way to do it would be something like:我不是 100% 确定你的真实用例是什么样的,但一个简单的方法是这样的:

import numpy as np
import numba as nb

@nb.njit
def test_func(x):
    idx = (0, 1, 3)
    res = np.empty((len(idx), x.shape[1]), dtype=x.dtype)
    for i, ix in enumerate(idx):
        res[i] = x[ix]

    return res

test_array = np.random.rand(4, 10)
print(test_array)
print()
print(test_func(test_array))

Edit: @kwinkunks is correct, and my original answer made an incorrect blanket statement that fancy indexing was not supported.编辑: @kwinkunks 是正确的,我最初的回答是不支持花式索引的不正确的一揽子声明。 It is in a limited set of cases, including this one.这是在一组有限的情况下,包括这个。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM