繁体   English   中英

在 numba njit 环境中将形状传递给 numpy.reshape 失败,如何为目标形状创建合适的可迭代对象?

[英]Passing a shape to numpy.reshape in a numba njit environment fails, how can I create a suitable iterable for the target shape?

我有一个函数,它接受一个数组,执行任意计算并返回一个可以广播的新形状。 我想在numba.njit环境中使用这个函数:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])
    
@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)
test()

但是,numba 不支持创建元组,当我尝试使用tuple()运算符将generate_target_shape的结果转换为元组时,出现以下错误消息:

No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(list(int64)<iv=None>)
 
There are 2 candidate implementations:
   - Of which 2 did not match due to:
   Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
     With argument(s): '(list(int64)<iv=None>)':
    No match.

During: resolving callee type: Function(<class 'tuple'>

如果我尝试将generate_target_shape的返回类型从tuple更改为listnp.array ,我会收到以下错误消息:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

有没有办法在nb.njit函数中创建一个可传递给np.reshape的可迭代对象?

numba 似乎不支持标准的 python 函数tuple() 您可以通过稍微重写代码来轻松解决此问题:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    a, b = [2, 2] # (this will also work if the list is a numpy array)
    return a, b

然而,一般情况要棘手得多。 我将回溯我在评论中所说的话:制作一个适用于许多不同大小的元组的 numba 编译函数是不可能或不可取的。 这样做将需要您为每个具有唯一大小的元组重新编译您的函数。 @Jérôme Richard 在这个 stackoverflow 答案中很好地解释了这个问题。

我建议你做的是简单地获取包含形状和数据的数组,并在你的 numba 编译函数之外计算my_array.reshape(tuple(target_shape)) 它不漂亮,但它可以让你继续你的项目。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM