繁体   English   中英

比较 numpy.ndarray 中特定位置的元素

[英]Comparing elements at specific positions in numpy.ndarray

我不知道标题是否描述了我的问题。 我有从 sigmoid 激活 function 获得的浮点数列表。

  outputs = 
  [[0.015161413699388504,
  0.6720218658447266,
  0.0024502829182893038,
  0.21356457471847534,
  0.002232735510915518,
  0.026410426944494247],
 [0.006432057358324528,
  0.0059209042228758335,
  0.9866275191307068,
  0.004609372932463884,
  0.007315939292311668,
  0.010821194387972355],
 [0.02358204871416092,
  0.5838017225265503,
  0.005475651007145643,
  0.012086033821106,
  0.540218658447266,
  0.010054176673293114]] 

为了计算我的指标,我想说如果任何神经元的 output 值大于 0.5,则假设该评论属于 class(多标签问题)。 我可以使用outputs = np.where(np.array(outputs) >= 0.5, 1, 0)轻松做到这一点但是,我想添加一个条件,仅考虑 class#5 和任何其他 class 时的较大值具有 > 0.5 的值(因为 class#5 不能与其他类一起出现)。 这个条件怎么写? 在我的示例中,output 应该是:

[[0 1 0 0 0 0]
 [0 0 1 0 0 0]
 [0 1 0 0 0 0]]

代替:

[[0 1 0 0 0 0]
 [0 0 1 0 0 0]
 [0 1 0 0 1 0]]

谢谢,

您可以编写自定义 function,然后您可以使用np.apply_along_axis() function 将其应用于outputs中的每个子数组:

def choose_class(a):
    if (len(np.argwhere(a >= 0.5)) > 1) & (a[4] >= 0.5):
        return np.where(a == a.max(), 1, 0)
    return np.where(a >= 0.5, 1, 0)

outputs = np.apply_along_axis(choose_class, 1, outputs)
outputs
# array([[0, 1, 0, 0, 0, 0],
#       [0, 0, 1, 0, 0, 0],
#       [0, 1, 0, 0, 0, 0]])

另一种解决方案:

# Your example input array
out = np.array([[0.015, 0.672, 0.002, 0.213, 0.002, 0.026],
                [0.006, 0.005, 0.986, 0.004, 0.007, 0.010],
                [0.023, 0.583, 0.005, 0.012, 0.540, 0.010]])

# We get the desired result
val = (out>=0.5)*out//(out.max(axis=1))[:,None]

此解决方案执行以下操作:

  1. 将所有值设置为零 < 0.5
  2. 按行将最大值设置为 1(如果此值 >= 0.5)

对于简单的面具,你不需要np.where

mask = outputs >= 0.5

如果您想要 integer 而不是 boolean:

mask = (outputs >= 0.5).view(np.uint8)

要检查第五列,您需要保留对原始数据的引用。 您可以获得每个相关行中的最大屏蔽值

rows = np.flatnonzero(mask[:, 4])
keep = (outputs[mask] * mask[rows]).argmax()

然后您可以将行清空并仅设置最大值:

mask[rows] = 0
mask[rows, keep] = 1

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM