Using numpy.where on multidimensional arrays
I have a 2D array in which each row represents the output of a classifier that classifies some input into 3 categories (the array size is 1000 × 3):
0.3 0.3 0.3
0.3 0.3 1.0
1.0 0.3 0.3
0.3 0.3 0.3
0.3 1.0 0.3
...
I want to get a list of all the inputs that the classifier is "not sure" about, and I'm defining "not sure" as no category being above 0.8.
To solve it, I use:
np.where(model1_preds.max(axis=1) < 0.8)
This works great.
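For illustration, here is the single-classifier version applied to just the five sample rows above (treated as the whole array). np.where with only a condition returns a tuple holding one index array per dimension, so for this 1D mask the indices are its only element.

```python
import numpy as np

# The five sample rows from above (one classifier, 3 categories per input)
model1_preds = np.array([
    [0.3, 0.3, 0.3],
    [0.3, 0.3, 1.0],
    [1.0, 0.3, 0.3],
    [0.3, 0.3, 0.3],
    [0.3, 1.0, 0.3],
])

# max(axis=1) gives the highest category score of each input (one value per row)
row_max = model1_preds.max(axis=1)

# np.where returns a 1-tuple for a 1D condition; unpack it to get the index array
(unsure_idx,) = np.where(row_max < 0.8)
print(unsure_idx.tolist())  # [0, 3]
```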
But now I have 6 classifiers (which have analyzed the same inputs in the same order) and a 6 × 1000 × 3 array representing their results.
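I want to find 2 things:
1. the inputs that all 6 classifiers are not sure about;
2. the inputs that at least one classifier is not sure about.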
I assume the general direction is something like this:
np.stack(np.where(model_preds.max(axis=1) < 0.8) for model_preds in all_preds)
But it won't work, because Python doesn't know what I mean in the for loop.
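One way to make the loop idea work, sketched under the assumption that all_preds is a NumPy array of shape 6×1000×3 (random stand-in data below): stack per-classifier boolean masks instead of np.where results. The masks all have length 1000, whereas the index arrays returned by np.where have a different length for each classifier and cannot form one rectangular array. np.stack is documented to take a sequence, so the sketch passes a list comprehension rather than a generator.

```python
import numpy as np

# Hypothetical stand-in data: 6 classifiers x 1000 inputs x 3 categories
all_preds = np.random.default_rng(0).random((6, 1000, 3))

# One boolean mask per classifier, each of length 1000, stacked from a list
unsure_masks = np.stack([model_preds.max(axis=1) < 0.8 for model_preds in all_preds])
print(unsure_masks.shape)  # (6, 1000)

# Reduce over the classifier axis (axis 0), then convert masks to input indices
all_unsure = np.where(unsure_masks.all(axis=0))[0]  # every classifier unsure
any_unsure = np.where(unsure_masks.any(axis=0))[0]  # at least one classifier unsure
```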
As an alternative to np.where:
res_all_unsure = preds[:,np.amax(preds, axis=(0,2)) <= 0.8,:]  # inputs where every classifier is unsure
res_one_unsure = preds[:,preds.max(-1).min(0) <= 0.8,:]  # inputs where at least one classifier is unsure
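Note that these two lines return the matching prediction rows (boolean indexing along axis 1), not the input indices. A small hand-made example, with 2 classifiers and 4 inputs instead of 6 and 1000, shows what each mask selects and how np.flatnonzero recovers the indices from the same masks:

```python
import numpy as np

# Small hand-made example: 2 classifiers x 4 inputs x 3 categories
preds = np.array([
    [[0.3, 0.3, 0.3], [0.3, 0.3, 1.0], [0.5, 0.4, 0.1], [0.9, 0.05, 0.05]],
    [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.1, 0.85, 0.05], [0.95, 0.03, 0.02]],
])

# Highest score over all classifiers and categories, per input -> shape (4,)
all_mask = np.amax(preds, axis=(0, 2)) <= 0.8
# Top score per classifier (2, 4), then the least confident classifier per input -> (4,)
one_mask = preds.max(-1).min(0) <= 0.8

res_all_unsure = preds[:, all_mask, :]  # prediction rows, shape (2, 1, 3)
res_one_unsure = preds[:, one_mask, :]  # prediction rows, shape (2, 3, 3)

# The same masks give the input indices
print(np.flatnonzero(all_mask).tolist())  # [0]
print(np.flatnonzero(one_mask).tolist())  # [0, 1, 2]
```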
If preds is already a 6×1000×3 matrix, you can first np.transpose() it into a 1000×6×3 matrix:
y = preds.transpose(1,0,2) # preds is the input matrix, 6x1000x3
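A quick check of this step, using random stand-in data for preds (an assumption; any 6×1000×3 array behaves the same). transpose(1, 0, 2) only swaps the first two axes, and it returns a view rather than a copy:

```python
import numpy as np

# Hypothetical random stand-in for the 6 x 1000 x 3 predictions
preds = np.random.default_rng(1).random((6, 1000, 3))

# (classifier, input, category) becomes (input, classifier, category)
y = preds.transpose(1, 0, 2)
print(y.shape)  # (1000, 6, 3)

# y[i, c] is classifier c's output for input i
print(np.array_equal(y[42, 5], preds[5, 42]))  # True
# No data is copied: y shares memory with preds
print(np.shares_memory(y, preds))  # True
```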
Next, we can turn it into a 1000×6 boolean matrix that tells us, for each experiment (input) and each classifier, whether all the values were less than 0.8:
y = np.all(y<0.8,axis=2)
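On a tiny hand-made preds (2 classifiers, 4 inputs, chosen only for illustration) the resulting boolean matrix looks like this; row i is input i and column c is classifier c:

```python
import numpy as np

# Small hand-made example: 2 classifiers x 4 inputs x 3 categories
preds = np.array([
    [[0.3, 0.3, 0.3], [0.3, 0.3, 1.0], [0.5, 0.4, 0.1], [0.9, 0.05, 0.05]],
    [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.1, 0.85, 0.05], [0.95, 0.03, 0.02]],
])

y = preds.transpose(1, 0, 2)  # (inputs, classifiers, categories) = (4, 2, 3)
y = np.all(y < 0.8, axis=2)   # True where all 3 scores are below 0.8 -> (4, 2)
print(y.tolist())
# [[True, True], [False, True], [True, False], [False, False]]
```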
Finally, we can use another np.all() to find where all the classifiers were unsure:
all_classifiers_unsure = np.where(np.all(y,axis=1)) # all classifiers
Or where any of the classifiers was unsure:
any_classifier_unsure = np.where(np.any(y,axis=1)) # any of the classifiers
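Because the condition passed to np.where here is 1D, each result is a 1-tuple of index arrays. A minimal sketch using a hand-written 4×2 boolean matrix in place of the 1000×6 one, unpacking the indices with [0]:

```python
import numpy as np

# Hand-written stand-in for the (inputs x classifiers) boolean matrix
y = np.array([[True, True], [False, True], [True, False], [False, False]])

all_classifiers_unsure = np.where(np.all(y, axis=1))  # rows that are all True
any_classifier_unsure = np.where(np.any(y, axis=1))   # rows with at least one True

# Element [0] of each tuple is the plain array of input indices
print(all_classifiers_unsure[0].tolist())  # [0]
print(any_classifier_unsure[0].tolist())   # [0, 1, 2]
```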
We can write it more concisely:
experiment_classifier = np.all(preds.transpose(1,0,2) < 0.8,axis=2)
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))
Although I am quite confident, please validate this by checking a few indices (some where the result is true and some where it is not).
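One way to do that validation, sketched with random stand-in data: compare the vectorized result against a plain Python loop over every input and classifier. The helper unsure() is hypothetical, introduced only for this check.

```python
import numpy as np

# Hypothetical random stand-in for the 6 x 1000 x 3 predictions
preds = np.random.default_rng(2).random((6, 1000, 3))

experiment_classifier = np.all(preds.transpose(1, 0, 2) < 0.8, axis=2)
all_classifiers_unsure = np.where(np.all(experiment_classifier, axis=1))[0]
any_classifier_unsure = np.where(np.any(experiment_classifier, axis=1))[0]

# Hypothetical helper: is classifier c unsure about input i?
def unsure(i, c):
    return max(preds[c, i]) < 0.8

n_classifiers, n_inputs = preds.shape[0], preds.shape[1]
expected_all = [i for i in range(n_inputs) if all(unsure(i, c) for c in range(n_classifiers))]
expected_any = [i for i in range(n_inputs) if any(unsure(i, c) for c in range(n_classifiers))]

print(all_classifiers_unsure.tolist() == expected_all)  # True
print(any_classifier_unsure.tolist() == expected_any)   # True
```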
EDIT
You can still use your proposed .max() < 0.8 method, but with axis=2:
experiment_classifier = preds.transpose(1,0,2).max(axis=2) < 0.8
all_classifiers_unsure = np.where(np.all(experiment_classifier,axis=1))
any_classifier_unsure = np.where(np.any(experiment_classifier,axis=1))
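A quick check, again on random stand-in data, that the max-based test matches the np.all-based one. It also sketches a variant not in the original answer that skips the transpose by reducing preds over axis 2 (categories) and then axis 0 (classifiers).

```python
import numpy as np

# Hypothetical random stand-in for the 6 x 1000 x 3 predictions
preds = np.random.default_rng(3).random((6, 1000, 3))

# "highest score < 0.8" is the same test as "all scores < 0.8"
via_max = preds.transpose(1, 0, 2).max(axis=2) < 0.8
via_all = np.all(preds.transpose(1, 0, 2) < 0.8, axis=2)
print(np.array_equal(via_max, via_all))  # True

# Variant without transposing: the mask has shape (classifiers, inputs) = (6, 1000)
unsure = preds.max(axis=2) < 0.8
all_unsure_no_t = np.flatnonzero(unsure.all(axis=0))
print(np.array_equal(all_unsure_no_t, np.flatnonzero(via_max.all(axis=1))))  # True
```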