簡體   English   中英

在numpy中計算超過閾值的數組值的最快方法

[英]Fastest way to count array values above a threshold in numpy

我有一個包含10 ^ 8個浮點的numpy數組,並想計算其中有多少個> =給定閾值。 速度至關重要,因為必須對大量此類陣列執行操作。 到目前為止的參賽者是

np.sum(myarray >= thresh)

np.size(np.where(np.reshape(myarray,-1) >= thresh))

計數矩陣中大於一個值的所有值的答案表明np.where()會更快,但是我發現時序結果不一致。 我的意思是,對於某些實現和布爾條件,np.size(np.where(cond))比np.sum(cond)快,但對於某些情況,則要慢一些。

具體來說,如果大部分條目滿足條件,則np.sum(cond)會明顯更快,但是如果很小一部分(可能小於十分之一),則np.size(np.where(cond))會獲勝。

問題分為兩部分:

  • 還有其他建議嗎?
  • np.size(np.where(cond))花費的時間隨着cond為真的條目數而增加是否有意義?

使用cython可能是一個不錯的選擇。

import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


DTYPE_f64 = np.float64
ctypedef np.float64_t DTYPE_f64_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef int count_above_cython(DTYPE_f64_t [:] arr_view, DTYPE_f64_t thresh) nogil:

    cdef int length, i, total
    total = 0
    length = arr_view.shape[0]

    for i in prange(length):
        if arr_view[i] >= thresh:
            total += 1

    return total


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def count_above(np.ndarray arr, DTYPE_f64_t thresh):

    cdef DTYPE_f64_t [:] arr_view = arr.ravel()
    cdef int total

    with nogil:
       total =  count_above_cython(arr_view, thresh)
    return total

建議的不同方法的時間。

myarr = np.random.random((1000,1000))
thresh = 0.33

In [6]: %timeit count_above(myarr, thresh)
1000 loops, best of 3: 693 µs per loop

In [9]: %timeit np.count_nonzero(myarr >= thresh)
100 loops, best of 3: 4.45 ms per loop

In [11]: %timeit np.sum(myarr >= thresh)
100 loops, best of 3: 4.86 ms per loop

In [12]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
10 loops, best of 3: 61.6 ms per loop

使用更大的數組:

In [13]: myarr = np.random.random(10**8)

In [14]: %timeit count_above(myarr, thresh)
10 loops, best of 3: 63.4 ms per loop

In [15]: %timeit np.count_nonzero(myarr >= thresh)
1 loops, best of 3: 473 ms per loop

In [16]: %timeit np.sum(myarr >= thresh)
1 loops, best of 3: 511 ms per loop

In [17]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
1 loops, best of 3: 6.07 s per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM