繁体   English   中英

使用掩码对二维数组进行切片

[英]Slice 2D array using mask

假设一个数组

0 = {ndarray: (4,)} [5 0 3 3]
1 = {ndarray: (4,)} [7 9 3 5]
2 = {ndarray: (4,)} [2 4 7 6]
3 = {ndarray: (4,)} [8 8 1 6]

我想要切片索引,其中epoch_label等于零

[1 1 0 0]

从上面,索引将是第二和第三个索引

  • 备注: epoch_label是一个整数,取值可以是 0,1,2,...

使用masked_where ,这会产生一些东西

[1 1 -- --]

而且,预期的输出应该是

[2 4 7 6]
[8 8 1 6]

但是,使用下面的代码

epoch_com = [np.random.randint(10, size=4) for _ in range(Nepochs)]
epoch_com_arr=np.array(epoch_com)
epoch_label=np.random.randint(2, size=Nepochs)
mm=ma.masked_where(epoch_label == 0, epoch_label)
expected_output=np.where(epoch_com_arr[mm,:])

上面的代码片段产生

0 = {ndarray: (14,)} [0 0 0 0 1 1 1 1 2 2 2 3 3 3]
1 = {ndarray: (14,)} [0 1 2 3 0 1 2 3 0 2 3 0 2 3]

这不符合我的意图

或者

expected_output=epoch_com_arr[mm,:]

其中产生

0 = {ndarray: (4,)} [7 9 3 5]
1 = {ndarray: (4,)} [7 9 3 5]
2 = {ndarray: (4,)} [5 0 3 3]
3 = {ndarray: (4,)} [5 0 3 3]

我可以知道如何解决这个问题吗

 In [242]: Nepochs = 4
 ...: epoch_com = [np.random.randint(10, size=4) for _ in range(Nepochs)]
 ...: epoch_com_arr=np.array(epoch_com)
 ...: epoch_label=np.random.randint(2, size=Nepochs)
 ...: mm=np.ma.masked_where(epoch_label == 0, epoch_label)
 ...: expected_output=np.where(epoch_com_arr[mm,:])

查看变量:

In [246]: epoch_com_arr       # a (4,4) array
Out[246]: 
array([[7, 1, 3, 3],
       [5, 6, 7, 8],
       [5, 6, 3, 8],
       [3, 5, 1, 1]])

我不知道您为什么使用“0 = {ndarray: (4,)} [5 0 3 3]”样式的显示。 这不是正常的numpy

我认为制作masked_array没有任何好处:

In [247]: epoch_label
Out[247]: array([0, 0, 1, 0])
In [248]: mm
Out[248]: 
masked_array(data=[--, --, 1, --],
             mask=[ True,  True, False,  True],
       fill_value=999999)

而只是将 0/1 转换为布尔值。 通常,当我们谈论“屏蔽”时,我们的意思是使用布尔数组作为索引,而不是使用np.ma

In [249]: epoch_label.astype(bool)
Out[249]: array([False, False,  True, False])

该布尔值可用于选择arr行,或者“取消选择”它们:

In [250]: epoch_com_arr[epoch_label.astype(bool),:]
Out[250]: array([[5, 6, 3, 8]])
In [251]: epoch_com_arr[~epoch_label.astype(bool),:]
Out[251]: 
array([[7, 1, 3, 3],
       [5, 6, 7, 8],
       [3, 5, 1, 1]])

我不认为np.where在这里有用。 这给出了epoch_com_arr[mm,:]非零项的索引,并且使用 np.ma` 数组进行索引是有问题的。

np.where可用于将epoch_label转换为索引:

In [252]: idx = np.nonzero(epoch_label)   # aka np.where
In [253]: idx
Out[253]: (array([2]),)
In [254]: epoch_com_arr[idx,:]
Out[254]: array([[[5, 6, 3, 8]]])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM