Using numpy.where on multidimensional arrays

I have a 2D array in which each row represents the output of a classifier that classifies some input into 3 categories (the array size is 1000 * 3):

0.3 0.3 0.3
0.3 0.3 1.0
1.0 0.3 0.3
0.3 0.3 0.3
0.3 1.0 0.3
...

I want to get a list of all the inputs that the classifier is "not sure" about. I define "not sure" as: no category is above 0.8.

To solve it, I use:

np.where(model1_preds.max(axis=1) < 0.8)

This works great.
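On the sample rows shown above, the single-classifier version behaves as follows. This is a minimal sketch that assumes `model1_preds` is a NumPy array of shape (n_inputs, 3). Note that `np.where` called with a single argument returns a tuple with one index array per dimension, so the input indices are in element `[0]`.

```python
import numpy as np

# The five sample rows from the question (n_inputs x 3 categories)
model1_preds = np.array([
    [0.3, 0.3, 0.3],
    [0.3, 0.3, 1.0],
    [1.0, 0.3, 0.3],
    [0.3, 0.3, 0.3],
    [0.3, 1.0, 0.3],
])

# Highest score per input; "not sure" means that maximum stays below 0.8
unsure = np.where(model1_preds.max(axis=1) < 0.8)

# np.where returns a 1-tuple for a 1D condition; take the index array out of it
unsure_idx = unsure[0]
print(unsure_idx.tolist())  # [0, 3]
```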

But now I have 6 classifiers (which analyzed the same inputs in the same order) and a 6 * 1000 * 3 array representing their results.

I want to find 2 things:

  1. All the inputs that at least one classifier was "not sure" about.
  2. All the inputs that all the classifiers were "not sure" about.

I assume the general direction is something like this:

np.stack(np.where(model_preds.max(axis=1) < 0.8) for model_preds in all_preds)

But it won't work, because Python doesn't know what I mean in the for loop.
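One way to make the loop idea work is to stack the boolean masks rather than the `np.where` results. Every mask has the same length (one entry per input), whereas the index arrays returned by `np.where` generally differ in length from one classifier to the next, so they cannot be stacked into a regular array. `np.stack` also expects a sequence such as a list, hence the list comprehension instead of a generator. This is a sketch that assumes `all_preds` has shape (6, 1000, 3); random data stands in for the real predictions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the real predictions: 6 classifiers x 1000 inputs x 3 categories
all_preds = rng.random((6, 1000, 3))

# One boolean mask per classifier, each of length 1000 (True = "not sure")
masks = np.stack([model_preds.max(axis=1) < 0.8 for model_preds in all_preds])
print(masks.shape)  # (6, 1000)

# 1. Inputs where at least one classifier was unsure
any_unsure = np.where(masks.any(axis=0))[0]
# 2. Inputs where every classifier was unsure
all_unsure = np.where(masks.all(axis=0))[0]
```

Reducing over axis 0 collapses the classifier dimension, leaving one True/False value per input.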

As an alternative to np.where:

res_all_unsure = preds[:, np.amax(preds, axis=(0,2)) <= 0.8, :]  # inputs where every classifier is unsure
res_one_unsure = preds[:, preds.max(-1).min(0) <= 0.8, :]        # inputs where at least one classifier is unsure
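The two lines above return the predictions for the selected inputs (arrays of shape (6, k, 3)), not their positions. If you need the input indices, as in the question, one option is to apply `np.flatnonzero` to the same boolean masks. This is a minimal sketch that assumes `preds` has shape (6, 1000, 3), with random data used only for illustration. Note that this answer tests `<= 0.8`, whereas the question used `< 0.8`; the two differ only for scores exactly equal to 0.8.

```python
import numpy as np

rng = np.random.default_rng(1)
preds = rng.random((6, 1000, 3))  # hypothetical data: classifiers x inputs x categories

# Max over classifiers and categories: <= 0.8 means every classifier is unsure
all_mask = np.amax(preds, axis=(0, 2)) <= 0.8
# Each classifier's max, then the smallest of those: <= 0.8 means at least one is unsure
one_mask = preds.max(-1).min(0) <= 0.8

idx_all_unsure = np.flatnonzero(all_mask)  # same as np.where(all_mask)[0]
idx_one_unsure = np.flatnonzero(one_mask)

# Boolean indexing on axis 1 keeps only the selected inputs
res_all_unsure = preds[:, all_mask, :]
print(res_all_unsure.shape[1] == idx_all_unsure.size)  # True
```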

If preds is already a 6×1000×3 matrix, you can first np.transpose() it into a 1000×6×3 matrix:

y = preds.transpose(1,0,2)  # preds is the input matrix, 6x1000x3

Next, we can turn it into a 1000×6 matrix that tells us, for each experiment (input) and each classifier, whether all the values were less than 0.8:

y = np.all(y<0.8,axis=2)
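The transpose only changes which axis the classifiers sit on. As a sketch (random data standing in for `preds`, assumed to be 6×1000×3), you can equally reduce the categories on the original layout and get a 6×1000 table, which is just the transpose of the 1000×6 one; the later `np.all`/`np.any` would then be taken over axis 0 instead of axis 1.

```python
import numpy as np

rng = np.random.default_rng(2)
preds = rng.random((6, 1000, 3))  # hypothetical data: classifiers x inputs x categories

# The answer's route: transpose to inputs x classifiers x categories, then reduce categories
y = np.all(preds.transpose(1, 0, 2) < 0.8, axis=2)  # shape (1000, 6)

# Equivalent without transposing: reduce categories directly, giving classifiers x inputs
z = np.all(preds < 0.8, axis=2)  # shape (6, 1000)

print(y.shape, z.shape)        # (1000, 6) (6, 1000)
print(np.array_equal(y, z.T))  # True
```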

Finally, we can use another np.all() to find where all the classifiers were unsure:

all_classifiers_unsure = np.where(np.all(y,axis=1))  # all classifiers

Or where any of the classifiers was unsure:

any_classifier_unsure = np.where(np.any(y,axis=1))   # any of the classifiers
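To see what the two results mean, here is a small hand-made example run through the same steps. It uses 2 classifiers and 3 inputs, and the values are invented for illustration.

```python
import numpy as np

# Hypothetical predictions: 2 classifiers x 3 inputs x 3 categories
preds = np.array([
    [[0.3, 0.3, 0.3],   # classifier 0, input 0: unsure
     [0.3, 0.3, 1.0],   # classifier 0, input 1: sure
     [0.3, 0.3, 0.3]],  # classifier 0, input 2: unsure
    [[0.3, 0.3, 0.3],   # classifier 1, input 0: unsure
     [0.3, 0.3, 0.3],   # classifier 1, input 1: unsure
     [1.0, 0.3, 0.3]],  # classifier 1, input 2: sure
])

y = np.all(preds.transpose(1, 0, 2) < 0.8, axis=2)  # inputs x classifiers
all_classifiers_unsure = np.where(np.all(y, axis=1))
any_classifier_unsure = np.where(np.any(y, axis=1))

print(all_classifiers_unsure[0].tolist())  # [0]
print(any_classifier_unsure[0].tolist())   # [0, 1, 2]
```

Only input 0 is unsure for both classifiers; inputs 1 and 2 each have one unsure classifier, so they appear only in the "any" result.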

We can write it more concisely:

experiment_classifier = np.all(preds.transpose(1,0,2) < 0.8,axis=2)
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))

Although I am quite confident, please validate this by checking a few indices (some where the result is True and some where it is False).
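One way to do that validation is to compare against a deliberately naive Python loop that applies the question's definition input by input. This is a sketch with random data standing in for the real `preds`; on real data you would swap in your own array.

```python
import numpy as np

rng = np.random.default_rng(3)
preds = rng.random((6, 1000, 3))  # hypothetical data: classifiers x inputs x categories

# Vectorized version from the answer
experiment_classifier = np.all(preds.transpose(1, 0, 2) < 0.8, axis=2)
all_classifiers_unsure = np.where(np.all(experiment_classifier, axis=1))[0]
any_classifier_unsure = np.where(np.any(experiment_classifier, axis=1))[0]

# Reference loop: check every input against every classifier explicitly
n_classifiers, n_inputs, _ = preds.shape
loop_all, loop_any = [], []
for i in range(n_inputs):
    unsure = [preds[c, i].max() < 0.8 for c in range(n_classifiers)]
    if all(unsure):
        loop_all.append(i)
    if any(unsure):
        loop_any.append(i)

print(all_classifiers_unsure.tolist() == loop_all)  # True
print(any_classifier_unsure.tolist() == loop_any)   # True
```

The loop is slow but easy to read, which makes it a useful reference while checking the vectorized version.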

EDIT

You can still use your proposed .max() < 0.8 approach, but with axis=2:

experiment_classifier = preds.transpose(1,0,2).max(axis=2) < 0.8
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))
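Both formulations agree because the largest of the three scores is below 0.8 exactly when every score is below 0.8. Here is a quick check of that equivalence, again on random stand-in data assumed to be 6×1000×3:

```python
import numpy as np

rng = np.random.default_rng(4)
preds = rng.random((6, 1000, 3))  # hypothetical data: classifiers x inputs x categories

via_all = np.all(preds.transpose(1, 0, 2) < 0.8, axis=2)  # every category below 0.8
via_max = preds.transpose(1, 0, 2).max(axis=2) < 0.8      # largest category below 0.8

print(np.array_equal(via_all, via_max))  # True
```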

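Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.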

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM