簡體   English   中英

測試 numpy 數組中的每個元素是否位於兩個值之間的簡單方法?

[英]Easy way to test if each element in an numpy array lies between two values?

我想知道是否有一種語法上簡單的方法來檢查 numpy 數組中的每個元素是否位於兩個數字之間。

換句話說,就像numpy.array([1,2,3,4,5]) < 5將返回array([True, True, True, True, False]) ,我想知道是否可以這樣做類似於這樣的東西:

1 < numpy.array([1,2,3,4,5]) < 5

……獲得……

array([False, True, True, True, False])

我知道我可以通過 boolean 測試的邏輯鏈接來獲得這一點,但我正在處理一些相當復雜的代碼,並且我正在尋找一個語法上干凈的解決方案。

有小費嗎?

一種解決方案是:

import numpy as np
a = np.array([1, 2, 3, 4, 5])
(a > 1).all() and (a < 5).all()
# False

如果您想要真值數組,請使用:

(a > 1) & (a < 5)
# array([False,  True,  True,  True, False])

另一種是使用numpy.any ,這是一個例子

import numpy as np
a = np.array([1,2,3,4,5])
np.any((a < 1)|(a > 5 ))

您還可以將矩陣居中並使用到 0 的距離

upper_limit = 5
lower_limit = 1
a = np.array([1,2,3,4,5])
your_mask = np.abs(a- 0.5*(upper_limit+lower_limit))<0.5*(upper_limit-lower_limit)

要記住的一件事是,比較將在兩側對稱,因此它可以執行1<x<51<=x<=5 ,但不能執行1<=x<5

在多維數組中,您可以使用建議的np.any()選項或比較運算符,而使用&and會引發錯誤。

使用比較運算符的示例(在多維數組上)

import numpy as np

arr = np.array([[1,5,1],
                [0,1,0],
                [0,0,0],
                [2,2,2]])

現在使用==如果要檢查數組值是否在范圍內,即 A < arr < B,或!=如果要檢查數組值是否在范圍外,即 arr < A 和 arr > B :

(arr<1) != (arr>3)
> array([[False,  True, False],
         [ True, False,  True],
         [ True,  True,  True],
         [False, False, False]])

(arr>1) == (arr<4)
> array([[False, False, False],
         [False, False, False],
         [False, False, False],
         [ True,  True,  True]])

將基於 NumPy 的方法與 Numba 加速循環進行比較很有趣:

import numpy as np
import numba as nb


def between(arr, a, b):
    return (arr > a) & (arr < b)


@nb.njit(fastmath=True)
def between_nb(arr, a, b):
    shape = arr.shape
    arr = arr.ravel()
    n = arr.size
    result = np.empty_like(arr, dtype=np.bool_)
    for i in range(n):
        result[i] = arr[i] > a or arr[i] < b
    return result.reshape(shape)

基准計算和繪制:

import pandas as pd
import matplotlib.pyplot as plt


def benchmark(
    funcs,
    args=None,
    kws=None,
    ii=range(4, 24),
    m=2 ** 15,
    is_equal=np.allclose,
    seed=0,
    unit="ms",
    verbose=True
):
    labels = [func.__name__ for func in funcs]
    units = {"s": 0, "ms": 3, "µs": 6, "ns": 9}
    args = tuple(args) if args else ()
    kws = dict(kws) if kws else {}
    assert unit in units
    np.random.seed(seed)
    timings = {}
    for i in ii:
        n = 2 ** i
        k = 1 + m // n
        if verbose:
            print(f"i={i}, n={n}, m={m}, k={k}")
        arrs = np.random.random((k, n))
        base = np.array([funcs[0](arr, *args, **kws) for arr in arrs])
        timings[n] = []
        for func in funcs:
            res = np.array([func(arr, *args, **kws) for arr in arrs])
            is_good = is_equal(base, res)
            timed = %timeit -n 8 -r 8 -q -o [func(arr, *args, **kws) for arr in arrs]
            timing = timed.best / k
            timings[n].append(timing if is_good else None)
            if verbose:
                print(
                    f"{func.__name__:>24}"
                    f"  {is_good!s:5}"
                    f"  {timing * (10 ** units[unit]):10.3f} {unit}"
                    f"  {timings[n][0] / timing:5.1f}x")
    return timings, labels
funcs = between, between_nb
timings, labels = benchmark(funcs, args=(0.25, 0.75), unit="µs", verbose=False)
plot(timings, labels, unit="µs")

導致: 在此處輸入圖像描述

表明(在我的測試條件下):

  • 對於更大和更小的輸入,Numba 方法可以快 20%
  • 對於中等大小的輸入,NumPy 方法通常更快

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM