简体   繁体   English

在 ndarray 中搜索 one-hot 编码的 label

[英]Search for a one-hot encoded label in ndarray

I have an ndarray called labels with a shape of (6000, 8) .我有一个名为labelsndarray ,形状为(6000, 8) This is 6000 one-hot encoded arrays with 8 categories.这是 6000 个 one-hot 编码的 arrays,有 8 个类别。 I want to search for labels that looks like this:我想搜索如下所示的标签:

[1,0,0,0,0,0,0,0]

and then tried to do like this然后尝试这样做

np.where(labels==[1,0,0,0,0,0,0,0,0])

but this does not produce the expected result但这不会产生预期的结果

You need all along the second axis:您需要沿着第二all轴:

np.where((labels == [1,0,0,0,0,0,0,0]).all(1))

See with this smaller example:看这个更小的例子:

labels = np.array([[1,0,0,1,0,0,0,0], 
                   [0,0,0,0,0,1,1,0], 
                   [1,0,0,0,0,0,0,0], 
                   [0,0,0,0,0,0,0,1]])

(labels == [1,0,0,0,0,0,0,0])

array([[ True,  True,  True, False,  True,  True,  True,  True],
       [False,  True,  True,  True,  True, False, False,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True],
       [False,  True,  True,  True,  True,  True,  True, False]])

Note that the above comparisson simply returns an array of the same shape as labels , since the comparisson has taken place along the rows of labels .请注意,上面的比较只是返回一个与labels形状相同的数组,因为比较是沿着labels行进行的。 You need to aggregate with all , to check whether all elements in a row are True :您需要与all聚合,以检查一行中的所有元素是否为True

(labels == [1,0,0,0,0,0,0,0]).all(1)
 #array([False, False,  True, False])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM