[英]How do I do masking in PyTorch / Numpy with different dimensions?
I have a mask
with a size of torch.Size([20, 1, 199])
and a tensor, reconstruct_output
and inputs
both with a size of torch.Size([20, 1, 161, 199])
.我有一个尺寸为torch.Size([20, 1, 199])
的mask
和一个张量、 reconstruct_output
和inputs
,它们的尺寸都为torch.Size([20, 1, 161, 199])
。
I want to set reconstruct_output
to inputs
where the mask
is 0
.我想将reconstruct_output
设置为mask
为0
的inputs
。 I tried:我试过了:
reconstruct_output[mask == 0] = inputs[mask == 0]
But I get an error:但我收到一个错误:
IndexError: The shape of the mask [20, 1, 199] at index 2 does not match the shape of the indexed tensor [20, 1, 161, 199] at index 2
We can use advanced indexing
here.我们可以在这里使用advanced indexing
。 To obtain the indexing arrays which we want to use to index both reconstruct_output
and inputs
, we need the indices along its axes where m==0
.为了获得索引 arrays ,我们想用它来索引reconstruct_output
和inputs
,我们需要沿着m==0
的轴的索引。 For that we can use np.where
, and use the resulting indices to update reconstruct_output
as:为此,我们可以使用np.where
,并使用生成的索引将reconstruct_output
更新为:
m = mask == 0
i, _, l = np.where(m)
reconstruct_output[i, ..., l] = inputs[i, ..., l]
Here's a small example which I've checked with:这是我检查过的一个小例子:
mask = np.random.randint(0,3, (2, 1, 4))
reconstruct_output = np.random.randint(0,10, (2, 1, 3, 4))
inputs = np.random.randint(0,10, (2, 1, 3, 4))
Giving for instance:举个例子:
print(reconstruct_output)
array([[[[8, 9, 7, 2],
[5, 4, 6, 1],
[1, 4, 0, 3]]],
[[[4, 3, 3, 4],
[0, 9, 9, 7],
[3, 4, 9, 3]]]])
print(inputs)
array([[[[7, 3, 9, 8],
[3, 1, 0, 8],
[0, 5, 4, 8]]],
[[[3, 7, 5, 8],
[2, 5, 3, 8],
[3, 6, 7, 5]]]])
And the mask
:和mask
:
print(mask)
array([[[0, 1, 2, 1]],
[[1, 0, 1, 0]]])
By using np.where
to find the indices where there are zeroes in mask
we get:通过使用np.where
找到mask
中为零的索引,我们得到:
m = mask == 0
i, _, l = np.where(m)
i
# array([0, 1, 1])
l
# array([0, 1, 3])
Hence we'll be replacing the 0th column from the first 2D array and the 1st and 3rd from the second 2D array.因此,我们将替换第一个二维数组的第 0 列以及第二个二维数组的第 1 列和第 3 列。
We can now use these arrays to replace along the corresponding axes indexing as:我们现在可以使用这些 arrays 沿相应的轴索引替换:
reconstruct_output[i, ..., l] = inputs[i, ..., l]
Getting:得到:
reconstruct_output
array([[[[7, 9, 7, 2],
[3, 4, 6, 1],
[0, 4, 0, 3]]],
[[[4, 7, 3, 8],
[0, 5, 9, 8],
[3, 6, 9, 5]]]])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.