簡體   English   中英

最好的方法來計算2d numpy數組中包含另一個1d數組的所有元素的所有行?

[英]Best way to count all rows in a 2d numpy array that include all elements of another 1d array?

在2d numpy數組中計算包含另一個1d numpy數組的所有值的行的最佳方法是什么? 第二個數組可以具有比第一個數組的長度更多的列。

elements = np.arange(4).reshape((2, 2))
test_elements = [2, 3]
somefunction(elements, test_elements)

我希望函數返回1。

elements = np.arange(15).reshape((5, 3))

# array([[ 0,  1,  2],
#       [ 3,  4,  5],
#       [ 6,  7,  8],
#       [ 9, 10, 11],
#       [12, 13, 14]])

test_elements = [4, 3]
somefunction(elements, test_elements)

還應該返回1。

必須包含1d數組的所有元素。 如果連續只能找到幾個元素,則不算在內。 因此:

elements = np.arange(15).reshape((5, 3))

# array([[ 0,  1,  2],
#       [ 3,  4,  5],
#       [ 6,  7,  8],
#       [ 9, 10, 11],
#       [12, 13, 14]])

test_elements = [3, 4, 10]
somefunction(elements, test_elements)

還應該返回0。

創建一個找到的元素的布爾數組,然后按行使用,這將避免在同一行中使用多個值,最后通過使用sum來對行進行計數,

np.any(np.isin(elements, test), axis=1).sum()

輸出量

>>> elements
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14]])
>>> test = [1, 6, 7, 4]
>>> np.any(np.isin(elements, test), axis=1).sum()
3

編輯 :好的,現在我實際上有更多時間弄清楚發生了什么。)


這里有兩個問題:

  1. 計算復雜度取決於兩個輸入的大小,並且不能被一維基准圖很好地捕獲
  2. 實際時序受輸入變化的影響

該問題可以分為兩個部分:

  1. 遍歷行
  2. 執行子集檢查,這基本上是一個嵌套循環二次運算(在最壞的情況下)

我們知道, 對於足夠大的輸入 ,在NumPy中通過行循環更快 ,而在純Python中則更

作為參考,讓我們考慮以下兩種方法:

# pure Python approach
def all_in_by_row_flt(arr, elems=ELEMS):
    return sum(1 for row in arr if all(e in row for e in elems))

# NumPy apprach (based on @Mstaino answer)
def all_in_by_row_np(arr, elems=ELEMS):
    def _aaa_helper(row, e=elems):
        return np.isin(e, row)
    return np.sum(np.all(np.apply_along_axis(_aaa_helper, 1, arr), 1))

然后,考慮子集檢查操作,如果輸入是在更少的循環內執行檢查,則純Python循環比NumPy更快。 相反,如果需要足夠多的循環,則NumPy實際上可以更快。 最重要的是,存在遍歷行的循環,但是由於子集檢查操作是二次的並且具有不同的常數系數,因此盡管在NumPy中行循環更快,但在某些情況下(因為行數會足夠大), 在純Python中整體操作更快。 這是我在較早的基准測試中遇到的情況,並且對應於子集檢查始終(或幾乎)為False且在幾個循環中確實失敗的情況。 一旦子集檢查開始需要更多的循環,Python的唯一方法開始落后以及其中該子集檢查實際上的情況True大部分(如果不是全部)行,在NumPy的方法實際上要快

NumPy和純Python方法之間的另一個關鍵區別是,純Python使用惰性求值,而NumPy不使用惰性求值,並且實際上需要創建可能會減慢計算速度的潛在大型中間對象。 最重要的是,NumPy對行進行兩次迭代( sum()一個, np.apply_along_axis() ),而純Python僅處理一次。


使用set().issubset()其他方法,例如來自@ GZ0的答案:

def all_in_by_row_set(arr, elems=ELEMS):
    elems = set(elems)
    return sum(map(elems.issubset, row))

與進行子集檢查時顯式編寫嵌套循環相比,它們具有不同的時序,但是它們仍然受較慢的外部循環的影響。


下一個是什么?

答案是使用CythonNumba 這樣做的想法是始終保持類似於NumPy的速度(讀取:C)(不僅對於足夠大的輸入),懶惰的求值和最少的循環行數。

Cython方法的一個示例(在IPython中使用%load_ext Cython魔術實現):

%%cython --cplus -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True


cdef long all_in_by_row_c(long[:, :] arr, long[:] elems) nogil:
    cdef long result = 0
    I = arr.shape[0]
    J = arr.shape[1]
    K = elems.shape[0]
    for i in range(I):
        is_subset = True
        for k in range(K):
            is_contained = False
            for j in range(J):
                if elems[k] == arr[i, j]:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result


def all_in_by_row_cy(long[:, :] arr, long[:] elems):
    return all_in_by_row_c(arr, elems)

雖然類似的Numba代碼顯示:

import numba as nb


@nb.jit(nopython=True, nogil=True)
def all_in_by_row_jit(arr, elems=ELEMS):
    result = 0
    n_rows, n_cols = arr.shape
    for i in range(n_rows):
        is_subset = True
        for e in elems:
            is_contained = False
            for r in arr[i, :]:
                if e == r:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result

現在,就時間而言,我們可以了解以下內容(相對較少的行數):

arr.shape=(100, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy  120 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit 129 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt 2.44 ms ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_set 9.98 ms ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np  13.7 ms ± 52.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

arr.shape=(100, 2000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy  1.45 ms ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_jit 1.52 ms ± 4.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_flt 30.1 ms ± 452 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_set 19.8 ms ± 56.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np  18 ms ± 28.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

arr.shape=(100, 3000) elems.shape=(1000,) result=37
Func: all_in_by_row_cy  10.4 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 10.9 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 226 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30.5 ms ± 92.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np  21.9 ms ± 87.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

arr.shape=(100, 4000) elems.shape=(1000,) result=86
Func: all_in_by_row_cy  16.8 ms ± 32.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 17.7 ms ± 42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 385 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 39.5 ms ± 588 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np  25.7 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

現在,無法通過第二維中輸入大小的增加來解釋最后一個塊的變慢。 實際上,如果提高短路率 (例如,通過更改隨機數組的值范圍),則對於最后一塊(輸入大小相同),將得到:

arr.shape=(100, 4000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy   152 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit  173 µs ± 4.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt  556 µs ± 8.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_set  39.7 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np   31.5 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

請注意,基於set()的方法在某種程度上與短路率無關(因為基於哈希的實現具有~O(1)檢查存在性的復雜性,但這是以哈希進行預計算和這些結果表明,這可能不會比直接嵌套循環方法快。

最后,對於更大的行數:

arr.shape=(100000, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy  141 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_jit 150 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_flt 2.6 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 10.1 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  13.7 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

arr.shape=(100000, 2000) elems.shape=(1000,) result=34
Func: all_in_by_row_cy  1.2 s ± 753 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 1.27 s ± 7.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 24.1 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 19.5 s ± 270 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  18 s ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

arr.shape=(100000, 3000) elems.shape=(1000,) result=33859
Func: all_in_by_row_cy  9.79 s ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 10.3 s ± 5.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 3min 30s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30 s ± 57.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  21.9 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

arr.shape=(100000, 4000) elems.shape=(1000,) result=86376
Func: all_in_by_row_cy  17 s ± 30.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 17.9 s ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 6min 29s ± 293 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 38.9 s ± 33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  25.7 s ± 29.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

最后,請注意, 可以通過算法優化Cython / Numba代碼。

可能有一個更有效的解決方案,但是如果您想要在其中包含test_elements “所有”元素的行,則可以反轉np.isin並將其沿着每一行應用,如下所示:

np.apply_along_axis(lambda x: np.isin(test_elements, x), 1, elements).all(1).sum()

以下是@ norok2解決方案的一個稍微更有效(但可讀性更差)的變體。

sum(map(set(test_elements).issubset, elements))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM