簡體   English   中英

在 numba njit 環境中將形狀傳遞給 numpy.reshape 失敗,如何為目標形狀創建合適的可迭代對象?

[英]Passing a shape to numpy.reshape in a numba njit environment fails, how can I create a suitable iterable for the target shape?

我有一個函數,它接受一個數組,執行任意計算並返回一個可以廣播的新形狀。 我想在numba.njit環境中使用這個函數:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])
    
@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)
test()

但是,numba 不支持創建元組,當我嘗試使用tuple()運算符將generate_target_shape的結果轉換為元組時,出現以下錯誤消息:

No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(list(int64)<iv=None>)
 
There are 2 candidate implementations:
   - Of which 2 did not match due to:
   Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
     With argument(s): '(list(int64)<iv=None>)':
    No match.

During: resolving callee type: Function(<class 'tuple'>

如果我嘗試將generate_target_shape的返回類型從tuple更改為listnp.array ,我會收到以下錯誤消息:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

有沒有辦法在nb.njit函數中創建一個可傳遞給np.reshape的可迭代對象?

numba 似乎不支持標准的 python 函數tuple() 您可以通過稍微重寫代碼來輕松解決此問題:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    a, b = [2, 2] # (this will also work if the list is a numpy array)
    return a, b

然而,一般情況要棘手得多。 我將回溯我在評論中所說的話:制作一個適用於許多不同大小的元組的 numba 編譯函數是不可能或不可取的。 這樣做將需要您為每個具有唯一大小的元組重新編譯您的函數。 @Jérôme Richard 在這個 stackoverflow 答案中很好地解釋了這個問題。

我建議你做的是簡單地獲取包含形狀和數據的數組,並在你的 numba 編譯函數之外計算my_array.reshape(tuple(target_shape)) 它不漂亮,但它可以讓你繼續你的項目。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM