簡體   English   中英

如何使用numba加速python函數

[英]How to speed up a python function with numba

我正在嘗試使用 numba 加速我對Floyd-Steinberg 抖動算法的實現 在閱讀了初學者指南后,我在代碼中添加了@jit裝飾器:

def findClosestColour(pixel):
    colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]])
    distances = np.sum(np.abs(pixel[:, np.newaxis].T - colors), axis=1)
    shortest = np.argmin(distances)
    closest_color = colors[shortest]
    return closest_color

@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def floydDither(img_array):
    height, width, _ = img_array.shape
    for y in range(0, height-1):
        for x in range(1, width-1):
            old_pixel = img_array[y, x, :]
            new_pixel = findClosestColour(old_pixel)
            img_array[y, x, :] = new_pixel
            quant_error = new_pixel - old_pixel
            img_array[y, x+1, :] =  img_array[y, x+1, :] + quant_error * 7/16
            img_array[y+1, x-1, :] =  img_array[y+1, x-1, :] + quant_error * 3/16
            img_array[y+1, x, :] =  img_array[y+1, x, :] + quant_error * 5/16
            img_array[y+1, x+1, :] =  img_array[y+1, x+1, :] + quant_error * 1/16
    return img_array

但是,我拋出以下錯誤:

Untyped global name 'findClosestColour': Cannot determine Numba type of <class 'function'>

我想我知道 numba 不知道findClosestColour的類型,但我剛開始使用 numba 並且不知道如何處理錯誤。

這是我用來測試該功能的代碼:

image = cv2.imread('logo.jpeg')
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_out = floydDither(img)

這是我使用的測試圖像。

首先,不能從 Numba nopython jitted 函數(又名 njit 函數)調用純 Python 函數 這是因為 Numba 需要在編譯時跟蹤類型以生成高效的二進制文件。

此外,Numba 無法編譯表達式pixel[:, np.newaxis].T因為np.newaxis似乎尚不受支持(可能是因為np.newaxisNone )。 您可以使用pixel.reshape(3, -1).T代替

請注意,您應該注意類型,因為當兩個變量都是np.uint8類型時執行a - b np.uint8導致可能的溢出(例如0 - 1 == 255 ,或者更令人驚訝的是: 0 - 256 = 65280b是文字整數和a類型的np.uint8 )。 請注意,該數組是就地計算的,並且像素是在之前寫入的


盡管 Numba 做得很好,但生成的代碼效率不會很高。 您可以使用循環自己迭代顏色以找到最小索引。 這要好一些,因為它不會生成很多小的臨時數組 您還可以指定類型,以便 Numba 提前編譯函數。 話雖如此。 這也使代碼級別較低,因此更冗長/更難以維護。

這是一個優化的實現

@nb.njit('int32[::1](uint8[::1])')
def nb_findClosestColour(pixel):
    colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]], dtype=np.int32)
    r,g,b = pixel.astype(np.int32)
    r2,g2,b2 = colors[0]
    minDistance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
    shortest = 0
    for i in range(1, colors.shape[0]):
        r2,g2,b2 = colors[i]
        distance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
        if distance < minDistance:
            minDistance = distance
            shortest = i
    return colors[shortest]

@nb.njit('uint8[:,:,::1](uint8[:,:,::1])')
def nb_floydDither(img_array):
    assert(img_array.shape[2] == 3)
    height, width, _ = img_array.shape
    for y in range(0, height-1):
        for x in range(1, width-1):
            old_pixel = img_array[y, x, :]
            new_pixel = nb_findClosestColour(old_pixel)
            img_array[y, x, :] = new_pixel
            quant_error = new_pixel - old_pixel
            img_array[y, x+1, :] =  img_array[y, x+1, :] + quant_error * 7/16
            img_array[y+1, x-1, :] =  img_array[y+1, x-1, :] + quant_error * 3/16
            img_array[y+1, x, :] =  img_array[y+1, x, :] + quant_error * 5/16
            img_array[y+1, x+1, :] =  img_array[y+1, x+1, :] + quant_error * 1/16
    return img_array

naive 版本快 14 倍,而最后一個版本快19 倍

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM