簡體   English   中英

有沒有一種很好的方法來檢查 numpy 數組元素是否在一個范圍內?

[英]Is there a nice way to check if numpy array elements are within a range?

我想寫:

assert np.all(0 < a < 2)

其中anumpy數組,但它不起作用。 有什么好方法來寫這個?

您可以使用numpy.logical_and

>>> a = np.repeat(1, 10)
>>> np.logical_and(a > 0, a < 2).all()
True

或使用&

>>> ((0 < a) & (a < 2)).all()
True

您可以在 NumPy 中使用以下任一方法實現此目的:

import numpy as np


def between_all_and(arr, a, b):
    return np.all((arr > a) & (arr < b))

或者:

import numpy as np


def between_and_all(arr, a, b):
    return np.all(arr > a) and np.all(arr < b)

(或者,等效地,通過調用np.ndarray.all()而不是np.all() )。

請注意, np.all()可以替換為all() ,這對於較小的輸入可能更快,但在較大的輸入上要慢得多。

雖然它們給出了相同的結果,但它們都具有次優的短路特性:

  • between_all_and() ("all of and") 將在訪問短路代碼 ( np.all() ) 之前計算arr > aarr < b arrays
  • 在執行所有arr > a測試之前, between_and_all() (“and of all”)不會在arr < b上短路。

在隨機分布的 arrays 上,這意味着兩者可能有非常不同的時序。

或者,可以使用 Numba 加速的基於循環的實現:

import numpy as np
import numba as nb


@nb.njit
def between_nb(arr, a, b):
    arr = arr.ravel()
    for x in arr:
        if x <= a or x >= b:
            return False
    return True

這具有更好的短路特性,並且不會產生潛在的大型臨時 arrays。

可以對包含均勻分布在 [0, 1] 范圍內的隨機數的 arrays (大小為n )的批次(大小為m )生成一些基准,以了解哪些方法更快以及速度快多少。


基准

假設在 [0, 1] 范圍內有一組均勻分布的隨機數,如果檢查不同的范圍,則可能產生具有不同短路的情況:

  • (0.0,0.999)等范圍的“平均情況”
  • (-1.0, 2.0)等范圍的“最壞情況”(無短路)
  • (2.0, 3.0)等范圍的“最佳情況”(可能立即短路)

基准產生於:

import pandas as pd
import matplotlib.pyplot as plt


def benchmark(
    funcs,
    args=None,
    kws=None,
    ii=range(4, 24),
    m=2 ** 15,
    is_equal=np.allclose,
    seed=0,
    unit="ms",
    verbose=True
):
    labels = [func.__name__ for func in funcs]
    units = {"s": 0, "ms": 3, "µs": 6, "ns": 9}
    args = tuple(args) if args else ()
    kws = dict(kws) if kws else {}
    assert unit in units
    np.random.seed(seed)
    timings = {}
    for i in ii:
        n = 2 ** i
        k = 1 + m // n
        if verbose:
            print(f"i={i}, n={n}, m={m}, k={k}")
        arrs = np.random.random((k, n))
        base = np.array([funcs[0](arr, *args, **kws) for arr in arrs])
        timings[n] = []
        for func in funcs:
            res = np.array([func(arr, *args, **kws) for arr in arrs])
            is_good = is_equal(base, res)
            timed = %timeit -n 8 -r 8 -q -o [func(arr, *args, **kws) for arr in arrs]
            timing = timed.best / k
            timings[n].append(timing if is_good else None)
            if verbose:
                print(
                    f"{func.__name__:>24}"
                    f"  {is_good!s:5}"
                    f"  {timing * (10 ** units[unit]):10.3f} {unit}"
                    f"  {timings[n][0] / timing:5.1f}x")
    return timings, labels

如下調用:

funcs = between_all_and, between_and_all, between_all_nb

avg_timings, avg_labels = benchmark(funcs, args=(0.01, 0.99), unit="µs", verbose=False)
wrs_timings, wrs_labels = benchmark(funcs, args=(-1.0, 2.0), unit="µs", verbose=False)
bst_timings, bst_labels = benchmark(funcs, args=(2.0, 3.0), unit="µs", verbose=False)
plot(avg_timings, avg_labels, "Average Case", unit="µs")
plot(wrs_timings, wrs_labels, "Worst Case", unit="µs")
plot(bst_timings, bst_labels, "Best Case", unit="µs")

生產:

bm_avg

bm_wrs

bm_bst

這些可以用來猜測在哪個體制下哪個更快。

通常,基於 Numba 的方法不僅效率最高,而且速度最快。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM