[英]Writing a 3D numpy array into a slice of a larger 3D array using a 2D mask
[英]Slice 2D array using mask
假设一个数组
0 = {ndarray: (4,)} [5 0 3 3]
1 = {ndarray: (4,)} [7 9 3 5]
2 = {ndarray: (4,)} [2 4 7 6]
3 = {ndarray: (4,)} [8 8 1 6]
我想要切片索引,其中epoch_label
等于零
[1 1 0 0]
从上面,索引将是第二和第三个索引
epoch_label
是一个整数,取值可以是 0,1,2,... 使用masked_where
,这会产生一些东西
[1 1 -- --]
而且,预期的输出应该是
[2 4 7 6]
[8 8 1 6]
但是,使用下面的代码
epoch_com = [np.random.randint(10, size=4) for _ in range(Nepochs)]
epoch_com_arr=np.array(epoch_com)
epoch_label=np.random.randint(2, size=Nepochs)
mm=ma.masked_where(epoch_label == 0, epoch_label)
expected_output=np.where(epoch_com_arr[mm,:])
上面的代码片段产生
0 = {ndarray: (14,)} [0 0 0 0 1 1 1 1 2 2 2 3 3 3]
1 = {ndarray: (14,)} [0 1 2 3 0 1 2 3 0 2 3 0 2 3]
这不符合我的意图
或者
expected_output=epoch_com_arr[mm,:]
其中产生
0 = {ndarray: (4,)} [7 9 3 5]
1 = {ndarray: (4,)} [7 9 3 5]
2 = {ndarray: (4,)} [5 0 3 3]
3 = {ndarray: (4,)} [5 0 3 3]
我可以知道如何解决这个问题吗
和
In [242]: Nepochs = 4
...: epoch_com = [np.random.randint(10, size=4) for _ in range(Nepochs)]
...: epoch_com_arr=np.array(epoch_com)
...: epoch_label=np.random.randint(2, size=Nepochs)
...: mm=np.ma.masked_where(epoch_label == 0, epoch_label)
...: expected_output=np.where(epoch_com_arr[mm,:])
查看变量:
In [246]: epoch_com_arr # a (4,4) array
Out[246]:
array([[7, 1, 3, 3],
[5, 6, 7, 8],
[5, 6, 3, 8],
[3, 5, 1, 1]])
我不知道您为什么使用“0 = {ndarray: (4,)} [5 0 3 3]”样式的显示。 这不是正常的numpy
。
我认为制作masked_array
没有任何好处:
In [247]: epoch_label
Out[247]: array([0, 0, 1, 0])
In [248]: mm
Out[248]:
masked_array(data=[--, --, 1, --],
mask=[ True, True, False, True],
fill_value=999999)
而只是将 0/1 转换为布尔值。 通常,当我们谈论“屏蔽”时,我们的意思是使用布尔数组作为索引,而不是使用np.ma
。
In [249]: epoch_label.astype(bool)
Out[249]: array([False, False, True, False])
该布尔值可用于选择arr
行,或者“取消选择”它们:
In [250]: epoch_com_arr[epoch_label.astype(bool),:]
Out[250]: array([[5, 6, 3, 8]])
In [251]: epoch_com_arr[~epoch_label.astype(bool),:]
Out[251]:
array([[7, 1, 3, 3],
[5, 6, 7, 8],
[3, 5, 1, 1]])
我不认为np.where
在这里有用。 这给出了epoch_com_arr[mm,:]
非零项的索引,并且使用 np.ma` 数组进行索引是有问题的。
np.where
可用于将epoch_label
转换为索引:
In [252]: idx = np.nonzero(epoch_label) # aka np.where
In [253]: idx
Out[253]: (array([2]),)
In [254]: epoch_com_arr[idx,:]
Out[254]: array([[[5, 6, 3, 8]]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.