
Numba - cannot infer type for List()

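I am trying to use numba to speed up the fuzzy search function of a Python package. My plan is to use njit sequentially first and, if that does not reach the target, move on to parallel execution. So I converted the original function from the library to numba-supported types, using a typed List instead of a normal Python list. numba throws the error "Cannot infer the type of variable 'candidates', have imprecise type: ListType[undefined]". I am confused about why this error occurs. Isn't this the way to declare a typed list variable?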

I am new to numba, so any suggestions for alternative, efficient ways to speed this up are welcome.

import numba
from numba import njit, types
from numba.typed import Dict, List


@njit
def make_char2first_subseq_index(subsequence, max_l_dist):
    d = Dict.empty(
        key_type=types.unicode_type,
        value_type=numba.int64,
    )
    for (index, char) in list(enumerate(subsequence[:max_l_dist + 1])):
        d[char] = index
    return d


@njit
def find_near_matches_levenshtein_linear_programming(subsequence, sequence,
                                                     max_l_dist):
    if not subsequence:
        raise ValueError('Given subsequence is empty!')

    subseq_len = len(subsequence)

    def make_match(start, end, dist):
        # return Match(start, end, dist, matched=sequence[start:end])
        return str(start) + " " + str(end) + " " + str(dist) + " " + str(sequence[start:end])

    if max_l_dist >= subseq_len:
        for index in range(len(sequence) + 1):
            return make_match(index, index, subseq_len)

    # optimization: prepare some often used things in advance
    char2first_subseq_index = make_char2first_subseq_index(subsequence,
                                                           max_l_dist)

    candidates = List()
    for index, char in enumerate(sequence):
        # print('/n new loop and the character is ', char)
        new_candidates = List()

        idx_in_subseq = char2first_subseq_index.get(char, None)
        # print("idx_in_subseq ", idx_in_subseq)
        if idx_in_subseq is not None:
            if idx_in_subseq + 1 == subseq_len:
                return make_match(index, index + 1, idx_in_subseq)
            else:
                new_candidates.append(List(index, idx_in_subseq + 1, idx_in_subseq))

        # print(candidates, " new candidates ", new_candidates)
        for cand in candidates:
            # if this sequence char is the candidate's next expected char
            if subsequence[cand[1]] == char:
                # if reached the end of the subsequence, return a match
                if cand[1] + 1 == subseq_len:
                    return make_match(cand[0], index + 1, cand[2])
                # otherwise, update the candidate's subseq_index and keep it
                else:
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2]))

            # if this sequence char is *not* the candidate's next expected char
            else:
                # we can try skipping a sequence or sub-sequence char (or both),
                # unless this candidate has already skipped the maximum allowed
                # number of characters
                if cand[2] == max_l_dist:
                    continue

                # add a candidate skipping a sequence char
                new_candidates.append(List(cand[0], cand[1], cand[2] + 1))

                if index + 1 < len(sequence) and cand[1] + 1 < subseq_len:
                    # add a candidate skipping both a sequence char and a
                    # subsequence char
                    new_candidates.append(List(cand[0], cand[1] + 1, cand[2] + 1))

                # try skipping subsequence chars
                for n_skipped in range(1, max_l_dist - cand[2] + 1):
                    # if skipping n_skipped sub-sequence chars reaches the end
                    # of the sub-sequence, yield a match
                    if cand[1] + n_skipped == subseq_len:
                        return make_match(cand[0], index + 1, cand[2] + n_skipped)
                        break
                    # otherwise, if skipping n_skipped sub-sequence chars
                    # reaches a sub-sequence char identical to this sequence
                    # char, add a candidate skipping n_skipped sub-sequence
                    # chars
                    elif subsequence[cand[1] + n_skipped] == char:
                        # if this is the last char of the sub-sequence, yield
                        # a match
                        if cand[1] + n_skipped + 1 == subseq_len:
                            return make_match(cand[0], index + 1,
                                             cand[2] + n_skipped)
                        # otherwise add a candidate skipping n_skipped
                        # subsequence chars
                        else:
                            new_candidates.append(List(cand[0], cand[1] + 1 + n_skipped, cand[2] + n_skipped))
                        break
                # note: if the above loop ends without a break, that means that
                # no candidate could be added / yielded by skipping sub-sequence
                # chars

        candidates = new_candidates

    for cand in candidates:
        dist = cand[2] + subseq_len - cand[1]
        if dist <= max_l_dist:
            return make_match(cand[0], len(sequence), dist)

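The error message is quite precise and points to the specific problem. A Numba typed.List holds a single, homogeneous element type, so Numba needs to know that type.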

You can create a typed list by initializing it with values:

import numba as nb

list_of_ints = nb.typed.List([1, 2, 3])

Or use the empty_list() factory to create an empty one with a declared element type:

empty_list_of_floats = nb.typed.List.empty_list(nb.f8)

Or create an empty one and immediately append an element:

another_list_of_ints = nb.typed.List()
another_list_of_ints.append(1)

Or any combination of these:

list_of_lists_of_floats = nb.typed.List()
list_of_lists_of_floats.append(nb.typed.List.empty_list(nb.f8))
list_of_lists_of_floats[0].append(1)

