简体   繁体   English

Numpy:从 nD 数组中提取最小值的索引

[英]Numpy: extract indices of the min values from the n-D array

I have an array我有一个数组

X = [[0.65108716 0.72213542 0.62142414 0.80734795 0.79485172 0.83946013
0.79192978 0.76614672 0.         0.84231442]
 [0.71353155 0.58493483 0.76903558 0.77678972 0.71837986 0.56127471
0.72591233 0.75986564 0.83495295 0.03016315]]

I need to extract indices of the top k min values from each subarray.我需要从每个子数组中提取前 k 最小值的索引。

If I want to get only 1 value, I can use如果我只想获得 1 个值,我可以使用

top_n_indices = np.argsort(X)[:, :1]

It gives me indices它给了我索引

[[8], [9]]

And then when I try to extract values with然后当我尝试用

np.take(X, top_n_indices)

It gives me incorrect answer它给了我不正确的答案

[[0.        ], [0.84231442]]

But it should be但应该是

[[0.        ], [0.03016315]]

Is it possible to do that without list comprehension?没有列表理解可以做到这一点吗?

At the end of the day, do you want the k minimum values from each subarray, or do you need their positions as well?归根结底,您是想要每个子数组的 k 个最小值,还是还需要它们的位置? If you only need the values, try the following (eg, to get the 3 lowest values in each subarray):如果您只需要这些值,请尝试以下操作(例如,获取每个子数组中的 3 个最低值):

import numpy as np

X = np.array([[0.65108716,0.72213542,0.62142414,0.80734795,0.79485172,0.83946013,0.79192978,0.76614672,0.,0.84231442],
[0.71353155,0.58493483,0.76903558,0.77678972,0.71837986,0.56127471,0.72591233,0.75986564,0.83495295,0.03016315]])

k=3
k_min_values=np.sort(X,axis=1)[:,:k]

print(k_min_values)

which will yield:这将产生:

[[0.         0.62142414 0.65108716]
 [0.03016315 0.56127471 0.58493483]]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM