繁体   English   中英

numba 无法将 Numpy 数组形状识别为 int

[英]Numpy array shape not recognized as int by numba

我想基于一些数组生成一个 numpy 矩阵,并使用jit或理想情况下njit加速这一生成。 如果nopython=False (启用nopython失败),它会继续发送以下 2 个警告,我无法理解:

:14: NumbaWarning: 由于函数“process_stuffs”类型推断失败,编译正在回退到启用 looplifting 的对象模式,原因是:没有从 array(int32, 2d, C) 到 array(int64, 2d, A) 的“inp”转换, 定义为无

文件“”,第 23 行: def process_stuffs(output,inp,route1, route2, zoneidx):

 input_pallets, _ = inp.shape ^

期间:在 (23) 处输入参数

文件“”,第 23 行: def process_stuffs(output,inp,route1, route2, zoneidx):

 input_pallets, _ = inp.shape ^

@jit(nopython=False, :14: NumbaWarning: 由于函数“process_stuffs”类型推断失败,因此编译正在回退到对象模式而没有启用 looplifting:无法确定 <class 'numba.core.dispatcher.LiftedLoop'> 的 Numba 类型

文件“”,第 25 行: def process_stuffs(output,inp,route1, route2, zoneidx):

 for minute in range(input_pallets): ^

@jit(nopython=False, C:\\Anaconda3\\envs\\dev38\\lib\\site-packages\\numba\\core\\object_mode_passes.py:151: NumbaWarning: 函数“process_stuffs”是在没有 forceobj=True 的对象模式下编译的,但是有提升循环。

虽然它是真实的函数中使用复杂类型,它无法在在确定的长度开始inp数组,然后它不希望产生一个循环,虽然我看到过很多这样的例子。

我试图通过使用locals指定类型来纠正错误,但正如您所见,它没有帮助。

这是一个最小的工作代码:

zoneidx=Dict.empty(key_type=types.unicode_type,value_type=types.int8)
zoneidx["A"]=np.int8(0)
zoneidx["B"]=np.int8(1)
zoneidx["C"]=np.int8(2)
zoneidx["D"]=np.int8(3)
zoneidx["E"]=np.int8(4)


output = np.zeros(shape=(110,5),dtype=np.int64)
inp = np.random.randint(0,2,size=(100,2))
route1 = np.random.choice(list('ABCDE'),size=10)
route2 = np.random.choice(list('ABCDE'),size=10)

@jit(nopython=False,
     locals={'input_pallets':numba.int64,
             'step':numba.int64,
             'inp':numba.types.int64[:,:],
             'route1':numba.types.unicode_type[:],
             'route2':numba.types.unicode_type[:],
             'output':numba.types.int64[:,:]})
def process_stuffs(output,inp,route1, route2, zoneidx):

    input_pallets, _ = inp.shape

    for minute in range(input_pallets):
        prod1, prod2 = inp[minute]
        if prod1+prod2 <1:
            continue

        if prod1:
            routing = route1
            number_of_pallets = prod1
            number_of_steps = route1.shape[0]
        else:
            routing = route2
            number_of_pallets = prod2
            number_of_steps = route2.shape[0]
        for step in range(number_of_steps):
            zone = routing[step]
            output[minute+step,zoneidx[zone]]+=number_of_pallets

    return output




numba.__version__ == 0.53.1
numpy.__version__ == 1.19.2

我的代码有什么问题?

注意:我对代码输出的正确性不感兴趣,我知道只要“inp”激活“route1”,就会忽略“route2”。 我只想编译它。

警告信息具有误导性。 实际上输入的类型确实没有正确给出,它与.shape方法无关。

我的解决方案是使用numba.typeof函数来告诉它期望的类型。 例如,int32 是预期的,而不是“inp”的 64。 并且“unichr”是预期的,而不是unicode。

这是我的最小示例的工作版本:

zoneidx=Dict.empty(key_type=numba.typeof(route1).dtype,value_type=types.int8)
zoneidx["A"]=np.int8(0)
zoneidx["B"]=np.int8(1)
zoneidx["C"]=np.int8(2)
zoneidx["D"]=np.int8(3)
zoneidx["E"]=np.int8(4)


output = np.zeros(shape=(110,5),dtype=np.int64)
inp = np.random.randint(0,2,size=(100,2))
route1 = np.random.choice(list('ABCDE'),size=10)
route2 = np.random.choice(list('ABCDE'),size=10)

@jit(nopython=False,
     locals={'input_pallets':numba.int64,
             'step':numba.int64,
             'inp':numba.types.int32[:,:],
             'route1':numba.typeof(route1),
             'route2':numba.typeof(route1),
             'output':numba.types.int64[:,:]})
def process_stuffs(output,inp,route1, route2, zoneidx):

    input_pallets, _ = inp.shape

    for minute in range(input_pallets):
        prod1, prod2 = inp[minute]
        if prod1+prod2 <1:
            continue

        if prod1:
            routing = route1
            number_of_pallets = prod1
            number_of_steps = route1.shape[0]
        else:
            routing = route2
            number_of_pallets = prod2
            number_of_steps = route2.shape[0]
        for step in range(number_of_steps):
            zone = routing[step]
            output[minute+step,zoneidx[zone]]+=number_of_pallets

    return output

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM