简体   繁体   English

在二维数组中将所有非最小值设置为 NaN

[英]Set all non min values to NaN in a 2D array

I have an array (based on deep learning losses).我有一个数组(基于深度学习损失)。 Let's say it looks like this (2 by 10):假设它看起来像这样(2 x 10):

losses = array([[31.27317047, 32.31885147, 31.32924271,  4.22141647, 32.43081665,
                 32.34402466, 31.84317207, 33.15940857, 32.0574379 , 32.89246368],
                [22.79278946,  2.29259634, 23.11773872, 24.65800285,  6.08445358,
                 23.774786  , 23.28055382, 24.63079453, 20.91534042, 24.70134735]])

(for those interested, the 2 corresponds to a deep learning batch dimension (in practise much higher of course) and 10 is the amount of predictions made by the model) (对于那些感兴趣的人,2 对应于深度学习批次维度(在实践中当然要高得多),10 是模型做出的预测量)

I can easily extract the minimum value or the indices of the minimum value with:我可以很容易地提取最小值或最小值的索引:

np.min(losses, axis=1) # lowest values
np.argmin(losses, axis=1) # indices of lowest values

However, I am looking for an efficient way to set all the non -lowest values to NaN values.但是,我正在寻找一种将所有最小值设置为 NaN 值的有效方法。

So in the end the array will look like this:所以最终数组将如下所示:

losses = array([[np.NaN,  np.NaN,     np.NaN,  4.22141647, np.NaN,
                 np.NaN,  np.NaN,     np.NaN,  np.NaN ,    np.NaN],
                [np.NaN,  2.29259634, np.NaN,  np.NaN,     np.NaN,
                 np.NaN,  np.NaN,     np.NaN,  np.NaN,     np.NaN]])

I could use a for loop for this, but I feel NumPy is not built for this, and there should be an efficient way to do this.我可以为此使用 for 循环,但我觉得 NumPy 不是为此而构建的,应该有一种有效的方法来做到这一点。

I took a look at the documentation, but have not found a solution yet.我查看了文档,但尚未找到解决方案。

Does anyone have some suggestions?有人有什么建议吗?

Thanks!谢谢!

You can use boolean indexing and broadcasting:您可以使用布尔索引和广播:

to make a new array:制作一个新数组:

out = np.where(losses == losses.min(1)[:,None], losses, np.nan)

to modify in place:就地修改:

losses[losses != losses.min(1)[:,None]] = np.nan

output:输出:

array([[       nan,        nan,        nan, 4.22141647,        nan,
               nan,        nan,        nan,        nan,        nan],
       [       nan, 2.29259634,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan]])

intermediates:中间体:

losses.min(axis=1)[:,None]
array([[4.22141647],
       [2.29259634]])

losses == losses.min(axis=1)[:,None]
array([[False, False, False,  True, False, False, False, False, False, False],
       [False,  True, False, False, False, False, False, False, False, False]])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM