繁体   English   中英

如何根据 where 子句创建具有额外维度的 numpy 数组?

[英]How to create a numpy array with a an extra dimension depending on an where clause?

问题

我需要创建一个采用 argmax 的数组,并基于该最大值 position 用[1,0]填充数组,而其他不是最大值的字段将用[0,1]填充。

例子:

给定向量a

a.shape = (3,2)
a = np.array([[1,0],[1,2],[1,3]])

返回向量b

b.shape = (3,2,2)
b = np.array([[[1,0],[0,1]],[[0,1],[1,0]],[[0,1],[1,0]]])
c = np.argmax(a, axis=1)
b = np.empty(tuple(list(a.shape) + [2]))
b[range(len(c)), c, :] = [1, 0]
b[range(len(c)), ~c, :] = [0, 1]
b
>>>array([[[1., 0.],
           [0., 1.]],

          [[0., 1.],
           [1., 0.]],

          [[0., 1.],
           [1., 0.]]])

请注意,这仅在此示例中有效,因为 argmax 将永远只有 0 或 1。如果a中的第二维大于 2,我认为此解决方案不会起作用

我能够创建一个 function 来返回理想的结果,但仅适用于两个类。 它可以适用于多个类:

a = np.array([[1,0],[1,2],[1,3]])

def create_dist_prob_target(arr):
    p_ = np.squeeze(arr,axis=1)
    a = np.expand_dims(np.where((p_ == np.amax(p_,axis = 1)[:,None]),1,0),axis=-1)
    b = np.expand_dims(np.where((p_ == np.amax(p_,axis = 1)[:,None]),0,1),axis=-1)
    return np.concatenate((a,b),axis=2)

b = create_dist_prob_target(a)
print(b)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM