[英]Set all non min values to NaN in a 2D array
I have an array (based on deep learning losses).我有一个数组(基于深度学习损失)。 Let's say it looks like this (2 by 10):
假设它看起来像这样(2 x 10):
losses = array([[31.27317047, 32.31885147, 31.32924271, 4.22141647, 32.43081665,
32.34402466, 31.84317207, 33.15940857, 32.0574379 , 32.89246368],
[22.79278946, 2.29259634, 23.11773872, 24.65800285, 6.08445358,
23.774786 , 23.28055382, 24.63079453, 20.91534042, 24.70134735]])
(for those interested, the 2 corresponds to a deep learning batch dimension (in practise much higher of course) and 10 is the amount of predictions made by the model) (对于那些感兴趣的人,2 对应于深度学习批次维度(在实践中当然要高得多),10 是模型做出的预测量)
I can easily extract the minimum value or the indices of the minimum value with:我可以很容易地提取最小值或最小值的索引:
np.min(losses, axis=1) # lowest values
np.argmin(losses, axis=1) # indices of lowest values
However, I am looking for an efficient way to set all the non -lowest values to NaN values.但是,我正在寻找一种将所有非最小值设置为 NaN 值的有效方法。
So in the end the array will look like this:所以最终数组将如下所示:
losses = array([[np.NaN, np.NaN, np.NaN, 4.22141647, np.NaN,
np.NaN, np.NaN, np.NaN, np.NaN , np.NaN],
[np.NaN, 2.29259634, np.NaN, np.NaN, np.NaN,
np.NaN, np.NaN, np.NaN, np.NaN, np.NaN]])
I could use a for loop for this, but I feel NumPy is not built for this, and there should be an efficient way to do this.我可以为此使用 for 循环,但我觉得 NumPy 不是为此而构建的,应该有一种有效的方法来做到这一点。
I took a look at the documentation, but have not found a solution yet.我查看了文档,但尚未找到解决方案。
Does anyone have some suggestions?有人有什么建议吗?
Thanks!谢谢!
You can use boolean indexing and broadcasting:您可以使用布尔索引和广播:
to make a new array:制作一个新数组:
out = np.where(losses == losses.min(1)[:,None], losses, np.nan)
to modify in place:就地修改:
losses[losses != losses.min(1)[:,None]] = np.nan
output:输出:
array([[ nan, nan, nan, 4.22141647, nan,
nan, nan, nan, nan, nan],
[ nan, 2.29259634, nan, nan, nan,
nan, nan, nan, nan, nan]])
intermediates:中间体:
losses.min(axis=1)[:,None]
array([[4.22141647],
[2.29259634]])
losses == losses.min(axis=1)[:,None]
array([[False, False, False, True, False, False, False, False, False, False],
[False, True, False, False, False, False, False, False, False, False]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.