繁体   English   中英

如何在 numba.njit() 中使用 numpy.array 的索引?

[英]how to use index of numpy.array in numba.njit()?

如何在 numba.njit() 中使用 numpy.array 的索引? 在下文中,如果使用 numba.njit,代码将退出并报错。 我发现错误归因于“b = a [idx]”。 但实际上,在 python 中应该是对的。 如何在 numba 中纠正它? 谢谢

@numba.njit()
def test(a):
    idx = np.where(a>5)
    b   = a[idx]
    return b

a = np.linspace(0,15,16).reshape([4,4])
b = test(a)

文档还支持高级索引的子集:只允许一个高级索引,并且它必须是一维数组

如果你在没有 numba 的情况下运行你的代码,你可以看到结果是一个一维数组:

>>> a[np.where(a > 5)]
array([ 6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.])

所以你可以直接对一维数组进行操作:

@nb.njit()
def test(a):
    a = a.ravel()
    idx = np.where(a > 5)
    b = a[idx]
    return b

或者更简单:

@nb.njit()
def test(a):
    a = a.ravel()
    return a[a > 5]

查看此文档

http://numba.pydata.org/numba-doc/0.15.1/arrays.html

b = 测试(a)

尝试在这里更改变量,例如

k=test(a),因为我认为 test(a) = b 这意味着 b=b

只要试试它是否有效......

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM