
Passing a shape to numpy.reshape in a numba njit environment fails: how can I create a suitable iterable for the target shape?

I have a function that takes in an array, performs an arbitrary calculation and returns a new shape that the array can be reshaped into. I would like to use this function in a numba.njit environment:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])
    
@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)
test()

However, numba does not support building a tuple from a list with tuple(), and I get the following error message when trying to convert the list inside generate_target_shape to a tuple:

No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(list(int64)<iv=None>)
 
There are 2 candidate implementations:
   - Of which 2 did not match due to:
   Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
     With argument(s): '(list(int64)<iv=None>)':
    No match.

During: resolving callee type: Function(<class 'tuple'>

If I change the return type of generate_target_shape from tuple to list or np.array, I receive the following error message:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

Is there a way for me to create an iterable object inside a nb.njit function that can be passed to np.reshape?

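It seems that the standard Python function tuple() is not supported by numba when applied to a list: in numba the length of a tuple is part of its type and must be known at compile time, while the length of a list is not. You can easily work around this issue by rewriting your code a little: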

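import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    a, b = [2, 2] # (this will also work if the list is a numpy array)
    return a, b   # returned as a homogeneous tuple of two integers

@nb.njit
def test():
    my_array = np.array([1, 2, 3, 4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)  # now receives a tuple
    print(reshaped)

test()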

The general case, however, is a lot trickier. I am going to backtrack on what I said in the comments: it is not practical (or advisable) to make a numba-compiled function that works with tuples of many different sizes. The tuple length is part of the compiled signature, so numba would have to compile a separate version of your function for every distinct tuple length. @Jérôme Richard explains the problem very well in this Stack Overflow answer.

What I would recommend is to simply take the array containing the shape, together with your data, and calculate my_array.reshape(tuple(target_shape)) outside of your numba-compiled function. It is not pretty, but it will allow you to continue with your project.
