[英]Is there a nice way to check if numpy array elements are within a range?
我想寫:
assert np.all(0 < a < 2)
其中a
是numpy
數組,但它不起作用。 有什么好方法來寫這個?
您可以使用numpy.logical_and
:
>>> a = np.repeat(1, 10)
>>> np.logical_and(a > 0, a < 2).all()
True
或使用&
。
>>> ((0 < a) & (a < 2)).all()
True
您可以在 NumPy 中使用以下任一方法實現此目的:
import numpy as np
def between_all_and(arr, a, b):
return np.all((arr > a) & (arr < b))
或者:
import numpy as np
def between_and_all(arr, a, b):
return np.all(arr > a) and np.all(arr < b)
(或者,等效地,通過調用np.ndarray.all()
而不是np.all()
)。
請注意, np.all()
可以替換為all()
,這對於較小的輸入可能更快,但在較大的輸入上要慢得多。
雖然它們給出了相同的結果,但它們都具有次優的短路特性:
between_all_and()
("all of and") 將在訪問短路代碼 ( np.all()
) 之前計算arr > a
和arr < b
arraysarr > a
測試之前, between_and_all()
(“and of all”)不會在arr < b
上短路。在隨機分布的 arrays 上,這意味着兩者可能有非常不同的時序。
或者,可以使用 Numba 加速的基於循環的實現:
import numpy as np
import numba as nb
@nb.njit
def between_nb(arr, a, b):
arr = arr.ravel()
for x in arr:
if x <= a or x >= b:
return False
return True
這具有更好的短路特性,並且不會產生潛在的大型臨時 arrays。
可以對包含均勻分布在 [0, 1] 范圍內的隨機數的 arrays (大小為n
)的批次(大小為m
)生成一些基准,以了解哪些方法更快以及速度快多少。
假設在 [0, 1] 范圍內有一組均勻分布的隨機數,如果檢查不同的范圍,則可能產生具有不同短路的情況:
基准產生於:
import pandas as pd
import matplotlib.pyplot as plt
def benchmark(
funcs,
args=None,
kws=None,
ii=range(4, 24),
m=2 ** 15,
is_equal=np.allclose,
seed=0,
unit="ms",
verbose=True
):
labels = [func.__name__ for func in funcs]
units = {"s": 0, "ms": 3, "µs": 6, "ns": 9}
args = tuple(args) if args else ()
kws = dict(kws) if kws else {}
assert unit in units
np.random.seed(seed)
timings = {}
for i in ii:
n = 2 ** i
k = 1 + m // n
if verbose:
print(f"i={i}, n={n}, m={m}, k={k}")
arrs = np.random.random((k, n))
base = np.array([funcs[0](arr, *args, **kws) for arr in arrs])
timings[n] = []
for func in funcs:
res = np.array([func(arr, *args, **kws) for arr in arrs])
is_good = is_equal(base, res)
timed = %timeit -n 8 -r 8 -q -o [func(arr, *args, **kws) for arr in arrs]
timing = timed.best / k
timings[n].append(timing if is_good else None)
if verbose:
print(
f"{func.__name__:>24}"
f" {is_good!s:5}"
f" {timing * (10 ** units[unit]):10.3f} {unit}"
f" {timings[n][0] / timing:5.1f}x")
return timings, labels
如下調用:
funcs = between_all_and, between_and_all, between_all_nb
avg_timings, avg_labels = benchmark(funcs, args=(0.01, 0.99), unit="µs", verbose=False)
wrs_timings, wrs_labels = benchmark(funcs, args=(-1.0, 2.0), unit="µs", verbose=False)
bst_timings, bst_labels = benchmark(funcs, args=(2.0, 3.0), unit="µs", verbose=False)
plot(avg_timings, avg_labels, "Average Case", unit="µs")
plot(wrs_timings, wrs_labels, "Worst Case", unit="µs")
plot(bst_timings, bst_labels, "Best Case", unit="µs")
生產:
這些可以用來猜測在哪個體制下哪個更快。
通常,基於 Numba 的方法不僅效率最高,而且速度最快。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.