[英]How to check if a list of numpy arrays contains a given test array?
我有一个numpy
数组的列表,例如,
a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]
我有一个测试阵列
b = np.random.rand(3, 3)
我想检查a
是否包含b
。 然而
b in a
引发以下错误:
ValueError:具有多个元素的数组的真值不明确。 使用a.any()或a.all()
我想要什么的正确方法是什么?
你可以只让形状的一个阵列(3, 3, 3)
出的a
:
a = np.asarray(a)
然后将其与b
比较(我们在这里比较浮点数,因此我们应该使用isclose()
)
np.all(np.isclose(a, b), axis=(1, 2))
例如:
a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
a = np.asarray(a)
b = a[1, ...] # set b to some value we know will yield True
np.all(np.isclose(a, b), axis=(1, 2))
# array([False, True, False])
好吧in
因为它有效
def in_(obj, iterable):
for elem in iterable:
if obj == elem:
return True
return False
现在的问题是,对于两个ndarray a
和b
, a == b
是一个数组(尝试),而不是布尔值,因此if a == b
失败。 解决方案是定义一个新功能
def array_in(arr, list_of_arr):
for elem in list_of_arr:
if (arr == elem).all():
return True
return False
a = [np.arange(5)] * 3
b = np.ones(5)
array_in(b, a) # --> False
该错误是因为,如果a
和b
是numpy arrays
则a == b
不会返回True
或False
,而是在a
元素比较a
和b
之后返回boolean
值的array
。
您可以尝试如下操作:
np.any([np.all(a_s == b) for a_s in a])
[np.all(a_s == b) for a_s in a]
这里所创建的列表boolean
值,通过的元件迭代a
和检查是否在所有的元素b
的和特定元件a
是相同的。
使用np.any
您可以检查数组中的任何元素是否为True
对于列表,元组,集合,frozenset,dict或collections.deque等容器类型,y中的表达式x等于any(x为e或y中的e为x == e)。
a[0]==b
是一个数组,其中包含a[0]
和b
逐元素比较。 该数组的整体真值显然是不明确的。 如果所有元素都匹配,或者如果至少一个匹配,则大多数匹配,它们是否相同? 因此, numpy
迫使您明确表达自己的意思。 您想知道的是测试所有元素是否相同。 您可以使用numpy
的all
方法来做到这一点:
any((b is e) or (b == e).all() for e in a)
或加入一个函数:
def numpy_in(arrayToTest, listOfArrays):
return any((arrayToTest is e) or (arrayToTest == e).all()
for e in listOfArrays)
使用numpy中的array_equal
import numpy as np
a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
b = np.random.rand(3,3)
for i in a:
if np.array_equal(b,i):
print("yes")
正如@jotasi所强调的,由于数组中元素之间的比较,真值是不明确的。 有一个以前这个问题的答案在这里 。 总体而言,您的任务可以通过多种方式完成:
您可以通过将列表转换为(3,3,3)形状的数组来使用“ in”运算符,如下所示:
>>> a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]
>>> a= np.asarray(a)
>>> b= a[1].copy()
>>> b in a
True
np.all:
>>> any(np.all((b==a),axis=(1,2))) True
list-comperhension:通过遍历每个数组来完成:
>>> any([(b == a_s).all() for a_s in a]) True
下面是上述三种方法的速度比较:
import numpy as np
import perfplot
perfplot.show(
setup=lambda n: np.asarray([np.random.rand(3*3).reshape(3,3) for i in range(n)]),
kernels=[
lambda a: a[-1] in a,
lambda a: any(np.all((a[-1]==a),axis=(1,2))),
lambda a: any([(a[-1] == a_s).all() for a_s in a])
],
labels=[
'in', 'np.all', 'list_comperhension'
],
n_range=[2**k for k in range(1,20)],
xlabel='Array size',
logx=True,
logy=True,
)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.