簡體   English   中英

如何迭代 cython(或 numba)中的列表列表?

[英]How to iterate over a list of lists in cython (or numba)?

我想要一個 function 作為參數接收列表列表,每個子列表具有不同的大小,並且可以迭代每個子列表(包含整數),將它們作為廣播傳遞給 numpy 數組並執行不同的操作(如平均值)。

讓我包括一個不使用 cython 的預期行為的簡單示例:

import numpy as np

mask = [[0, 1, 2, 4, 6, 7, 8, 9],
        [0, 1, 2, 4, 6, 7, 8, 9],
        [0, 1, 2, 4, 6, 9],
        [3, 5, 8],
        [0, 1, 2, 4, 6, 7, 8, 9],
        [3, 5, 7],
        [0, 1, 2, 4, 6, 9],
        [0, 1, 4, 5, 7, 8, 9],
        [0, 1, 3, 4, 7, 8, 9],
        [0, 1, 2, 4, 6, 7, 8, 9]] # This is the list of lists

x = np.array([2.0660689 , 2.08599832, 0.45032649, 1.05435649, 2.06010132,
              1.07633407, 0.43014785, 1.54286467, 1.644388  , 2.15417444])

def nocython(mask, x):
    out = np.empty(len(x), dtype=np.float64)
    for i, v in enumerate(mask):
        out[i] = x[v].mean()
    return out

>>> nocython(mask, x)
array([1.55425875, 1.55425875, 1.54113622, 1.25835952, 1.55425875,
       1.22451841, 1.54113622, 1.80427567, 1.80113602, 1.55425875])

主要問題是我必須處理更大的 numpy arrays 和掩碼列表,並且循環在 Python 中變得非常低效。 所以我想知道如何對這個 function 進行 cythonize(或麻木)。 像這樣的東西:

%%cython

import numpy as np
cimport numpy as np

cdef np.ndarray[np.float64_t] cythonloop(int[:,:] mask, np.ndarray[np.float64_t] x):
    cdef Py_ssize_t i
    cdef Py_ssize_t N = len(x)
    cdef np.ndarray[np.float64_t] out = np.empty(N, dtype=np.float64)
    for i in range(N):
        out[i] = x[mask[i]]

cythonloop(mask, x)

但這不起作用(不能強制列表輸入'int [:, :]')。

如果我在 numba 中嘗試也不會

import numba as nb

@nb.njit
def nocython(mask, x):
    out = np.empty(len(x), dtype=np.float64)
    for i, v in enumerate(mask):
        out[i] = x[v].mean()
    return out

這給出了以下錯誤:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 1d, A), reflected list(int64))
 * parameterized

在 Numba 中,您可以使用Typed List對列表列表進行迭代。 Numba 不支持使用列表索引 NumPy 數組,因此 function 還需要進行一些修改以通過迭代內部列表的元素並索引到x來實現平均值。

在調用 jitted function 之前,您還需要將列表列表轉換為類型列表的類型列表。

把它放在一起給出(除了你的問題的代碼):

from numba import njit
from numba.typed import List

@njit
def jitted(mask, x): 
    out = np.empty(len(x), dtype=np.float64)
    for i in range(len(mask)):
        m_i = mask[i]
        s = 0 
        for j in range(len(m_i)):
            s += x[m_i[j]]
        out[i] = s / len(m_i)
    return out 

typed_mask = List()
for m in mask:
    typed_mask.append(List(m))

# Sanity check - Numba and nocython implementations produce the same result
np.testing.assert_allclose(nocython(mask, x),  jitted(typed_mask, x))

請注意,也可以避免將列表設置為類型化列表,因為 Numba 在傳遞內置列表類型時將使用反射列表- 但是此功能已棄用並將從 Numba 的未來版本中刪除,因此建議改用 Typed List。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM