[英]Passing a shape to numpy.reshape in a numba njit environment fails, how can I create a suitable iterable for the target shape?
I have a function that takes in an array, performs an arbitrary calculation and returns a new shape in which it can be broadcasted.我有一个函数,它接受一个数组,执行任意计算并返回一个可以广播的新形状。 I would like to use this function in a
numba.njit
environment:我想在
numba.njit
环境中使用这个函数:
import numpy as np
import numba as nb
@nb.njit
def generate_target_shape(my_array):
### some functionality that calculates the desired target shape ###
return tuple([2,2])
@nb.njit
def test():
my_array = np.array([1,2,3,4])
target_shape = generate_target_shape(my_array)
reshaped = my_array.reshape(target_shape)
print(reshaped)
test()
However, tuple creation is not supported in numba and I get the following error message when trying to cast the result of generate_target_shape
to a tuple with the tuple()
operator:但是,numba 不支持创建元组,当我尝试使用
tuple()
运算符将generate_target_shape
的结果转换为元组时,出现以下错误消息:
No implementation of function Function(<class 'tuple'>) found for signature:
>>> tuple(list(int64)<iv=None>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
With argument(s): '(list(int64)<iv=None>)':
No match.
During: resolving callee type: Function(<class 'tuple'>
If I try to change the return type of generate_target_shape
from tuple
to list
or np.array
, I receive the following error message:如果我尝试将
generate_target_shape
的返回类型从tuple
更改为list
或np.array
,我会收到以下错误消息:
Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))
Is there a way for me to create an iterable object inside a nb.njit
function that can be passed to np.reshape
?有没有办法在
nb.njit
函数中创建一个可传递给np.reshape
的可迭代对象?
It seems like the standard python function tuple()
is not supported by numba. numba 似乎不支持标准的 python 函数
tuple()
。 You can easily work around this issue by rewriting your code a litte bit:您可以通过稍微重写代码来轻松解决此问题:
import numpy as np
import numba as nb
@nb.njit
def generate_target_shape(my_array):
### some functionality that calculates the desired target shape ###
a, b = [2, 2] # (this will also work if the list is a numpy array)
return a, b
The general case however, is a lot trickier.然而,一般情况要棘手得多。 I am going to backtrack on what i said in the comments: it is not possible or advisable to make a numba compiled function that works with tuples of many different sizes.
我将回溯我在评论中所说的话:制作一个适用于许多不同大小的元组的 numba 编译函数是不可能或不可取的。 Doing so would require you to recompile your function for every tuple of an unique size.
这样做将需要您为每个具有唯一大小的元组重新编译您的函数。 @Jérôme Richard explains the problem very well in this stackoverflow answer .
@Jérôme Richard 在这个 stackoverflow 答案中很好地解释了这个问题。
What i would recommend that you do, is to simply take the array containing the shape, and your data, and calculate my_array.reshape(tuple(target_shape))
outside of your numba compiled function.我建议你做的是简单地获取包含形状和数据的数组,并在你的 numba 编译函数之外计算
my_array.reshape(tuple(target_shape))
。 It is not pretty, but it will allow you to continue with your project.它不漂亮,但它可以让你继续你的项目。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.