簡體   English   中英

將 function 及其所有子函數傳遞給 njit

[英]Pass a function and all its subfunctions into njit

所以我最近發現了 Numba,我對此感到非常驚訝。 When trying it out I've used a bubblesort function as the test function, but since my bubblesort function calls another function I get errors when calling njit on it.

我已經解決了這個問題,首先在我的bubblesort子函數上調用njit,然后讓我的bubblesort調用njit子函數,它可以工作,但它迫使我在嘗試比較時定義兩個bubblesort函數。 我想知道是否有另一種方法可以做到這一點。

這就是我正在做的事情:

def bytaintill(l):
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed


bytaintill_njit = njit()(bytaintill)

def bubblesort(l):
    not_done = True
    while not_done:
        not_done = bytaintill_njit(l)
    return

def bubble(l):
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return

bubblesort_njit = njit()(bubblesort)

要擴展我的評論,您不需要定義新功能,但也可以將 jit-ed 版本 map 改成同名。 通常,最方便的方法是使用@jit裝飾器(或@njit ,它是@jit(nopython=True)的縮寫)。

from numba import njit

@njit
def bytaintill(l):
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed

@njit
def bubble(l):
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return

出於基准測試的目的,您可以簡單地注釋掉裝飾器。 如果您希望能夠在 jit-ed 和 python 版本之間來回切換 go,您可以嘗試這樣的操作:

from numba import njit

do_jit = True  # set to True or False

def bytaintill(l):
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed

def bubble(l):
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return

if do_jit:
    bytaintill = njit()(bytaintill)
    bubble = njit()(bubble)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM