[英]How to create a numpy array with a an extra dimension depending on an where clause?
我需要创建一个采用 argmax 的数组,并基于该最大值 position 用[1,0]
填充数组,而其他不是最大值的字段将用[0,1]
填充。
给定向量a
:
a.shape = (3,2)
a = np.array([[1,0],[1,2],[1,3]])
返回向量b
:
b.shape = (3,2,2)
b = np.array([[[1,0],[0,1]],[[0,1],[1,0]],[[0,1],[1,0]]])
c = np.argmax(a, axis=1)
b = np.empty(tuple(list(a.shape) + [2]))
b[range(len(c)), c, :] = [1, 0]
b[range(len(c)), ~c, :] = [0, 1]
b
>>>array([[[1., 0.],
[0., 1.]],
[[0., 1.],
[1., 0.]],
[[0., 1.],
[1., 0.]]])
请注意,这仅在此示例中有效,因为 argmax 将永远只有 0 或 1。如果a
中的第二维大于 2,我认为此解决方案不会起作用
我能够创建一个 function 来返回理想的结果,但仅适用于两个类。 它可以适用于多个类:
a = np.array([[1,0],[1,2],[1,3]])
def create_dist_prob_target(arr):
p_ = np.squeeze(arr,axis=1)
a = np.expand_dims(np.where((p_ == np.amax(p_,axis = 1)[:,None]),1,0),axis=-1)
b = np.expand_dims(np.where((p_ == np.amax(p_,axis = 1)[:,None]),0,1),axis=-1)
return np.concatenate((a,b),axis=2)
b = create_dist_prob_target(a)
print(b)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.