[英]Search for a one-hot encoded label in ndarray
I have an ndarray
called labels
with a shape of (6000, 8)
.我有一个名为
labels
的ndarray
,形状为(6000, 8)
。 This is 6000 one-hot encoded arrays with 8 categories.这是 6000 个 one-hot 编码的 arrays,有 8 个类别。 I want to search for labels that looks like this:
我想搜索如下所示的标签:
[1,0,0,0,0,0,0,0]
and then tried to do like this然后尝试这样做
np.where(labels==[1,0,0,0,0,0,0,0,0])
but this does not produce the expected result但这不会产生预期的结果
You need all
along the second axis:您需要沿着第二
all
轴:
np.where((labels == [1,0,0,0,0,0,0,0]).all(1))
See with this smaller example:看这个更小的例子:
labels = np.array([[1,0,0,1,0,0,0,0],
[0,0,0,0,0,1,1,0],
[1,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,1]])
(labels == [1,0,0,0,0,0,0,0])
array([[ True, True, True, False, True, True, True, True],
[False, True, True, True, True, False, False, True],
[ True, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, False]])
Note that the above comparisson simply returns an array of the same shape as labels
, since the comparisson has taken place along the rows of labels
.请注意,上面的比较只是返回一个与
labels
形状相同的数组,因为比较是沿着labels
行进行的。 You need to aggregate with all
, to check whether all elements in a row are True
:您需要与
all
聚合,以检查一行中的所有元素是否为True
:
(labels == [1,0,0,0,0,0,0,0]).all(1)
#array([False, False, True, False])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.