[英]how to use index of numpy.array in numba.njit()?
如何在 numba.njit() 中使用 numpy.array 的索引? 在下文中,如果使用 numba.njit,代码将退出并报错。 我发现错误归因于“b = a [idx]”。 但实际上,在 python 中应该是对的。 如何在 numba 中纠正它? 谢谢
@numba.njit()
def test(a):
idx = np.where(a>5)
b = a[idx]
return b
a = np.linspace(0,15,16).reshape([4,4])
b = test(a)
文档说还支持高级索引的子集:只允许一个高级索引,并且它必须是一维数组。
如果你在没有 numba 的情况下运行你的代码,你可以看到结果是一个一维数组:
>>> a[np.where(a > 5)]
array([ 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.])
所以你可以直接对一维数组进行操作:
@nb.njit()
def test(a):
a = a.ravel()
idx = np.where(a > 5)
b = a[idx]
return b
或者更简单:
@nb.njit()
def test(a):
a = a.ravel()
return a[a > 5]
查看此文档
http://numba.pydata.org/numba-doc/0.15.1/arrays.html
b = 测试(a)
尝试在这里更改变量,例如
k=test(a),因为我认为 test(a) = b 这意味着 b=b
只要试试它是否有效......
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.