简体   繁体   English

在 numba njit 环境中将形状传递给 numpy.reshape 失败,如何为目标形状创建合适的可迭代对象?

[英]Passing a shape to numpy.reshape in a numba njit environment fails, how can I create a suitable iterable for the target shape?

I have a function that takes in an array, performs an arbitrary calculation and returns a new shape in which it can be broadcasted.我有一个函数,它接受一个数组,执行任意计算并返回一个可以广播的新形状。 I would like to use this function in a numba.njit environment:我想在numba.njit环境中使用这个函数:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])
    
@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)
test()

However, tuple creation is not supported in numba and I get the following error message when trying to cast the result of generate_target_shape to a tuple with the tuple() operator:但是,numba 不支持创建元组,当我尝试使用tuple()运算符将generate_target_shape的结果转换为元组时,出现以下错误消息:

No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(list(int64)<iv=None>)
 
There are 2 candidate implementations:
   - Of which 2 did not match due to:
   Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
     With argument(s): '(list(int64)<iv=None>)':
    No match.

During: resolving callee type: Function(<class 'tuple'>

If I try to change the return type of generate_target_shape from tuple to list or np.array , I receive the following error message:如果我尝试将generate_target_shape的返回类型从tuple更改为listnp.array ,我会收到以下错误消息:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

Is there a way for me to create an iterable object inside a nb.njit function that can be passed to np.reshape ?有没有办法在nb.njit函数中创建一个可传递给np.reshape的可迭代对象?

It seems like the standard python function tuple() is not supported by numba. numba 似乎不支持标准的 python 函数tuple() You can easily work around this issue by rewriting your code a litte bit:您可以通过稍微重写代码来轻松解决此问题:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    a, b = [2, 2] # (this will also work if the list is a numpy array)
    return a, b

The general case however, is a lot trickier.然而,一般情况要棘手得多。 I am going to backtrack on what i said in the comments: it is not possible or advisable to make a numba compiled function that works with tuples of many different sizes.我将回溯我在评论中所说的话:制作一个适用于许多不同大小的元组的 numba 编译函数是不可能或不可取的。 Doing so would require you to recompile your function for every tuple of an unique size.这样做将需要您为每个具有唯一大小的元组重新编译您的函数。 @Jérôme Richard explains the problem very well in this stackoverflow answer . @Jérôme Richard 在这个 stackoverflow 答案中很好地解释了这个问题。

What i would recommend that you do, is to simply take the array containing the shape, and your data, and calculate my_array.reshape(tuple(target_shape)) outside of your numba compiled function.我建议你做的是简单地获取包含形状和数据的数组,并在你的 numba 编译函数之外计算my_array.reshape(tuple(target_shape)) It is not pretty, but it will allow you to continue with your project.它不漂亮,但它可以让你继续你的项目。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM