I have a `mask`, which has a shape of `[64, 2895]`, and an array `pred`, which has a shape of `[64, 2895, 161]`. `mask` is binary with only `0`s and `1`s. What I want to do is reduce `pred` so that it maintains the `64` batches, and along the `2895` axis, wherever there is a `1` in the `mask` for each batch, return the related `pred` row.
So as a simplified example, if:
mask = [[1, 0, 0],
        [1, 1, 0],
        [0, 0, 1]]

pred = [[[0.12, 0.23, 0.45, 0.56, 0.57],
         [0.91, 0.98, 0.97, 0.96, 0.95],
         [0.24, 0.46, 0.68, 0.80, 0.15]],
        [[1.12, 1.23, 1.45, 1.56, 1.57],
         [1.91, 1.98, 1.97, 1.96, 1.95],
         [1.24, 1.46, 1.68, 1.80, 1.15]],
        [[2.12, 2.23, 2.45, 2.56, 2.57],
         [2.91, 2.98, 2.97, 2.96, 2.95],
         [2.24, 2.46, 2.68, 2.80, 2.15]]]
What I want is:
[[[0.12, 0.23, 0.45, 0.56, 0.57]],
 [[1.12, 1.23, 1.45, 1.56, 1.57],
  [1.91, 1.98, 1.97, 1.96, 1.95]],
 [[2.24, 2.46, 2.68, 2.80, 2.15]]]
I realize that the result has a different size for each batch; I hope that's possible. If not, then fill in the missing dimensions with `0`. Either `numpy` or `pytorch` would be helpful. Thank you.
If you want a fully vectorized computation, a result with a different length per batch does not seem possible in a single tensor, but the following gives you `pred` with the masked-out entries filled with 0:
# pred: torch.Size([64, 2895, 161])
# mask: torch.Size([64, 2895])
# Add a trailing dimension to mask so it broadcasts over the last axis of pred
result = pred * mask[:, :, None]
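Here `result` has the same shape as `pred`, with every row whose mask value is 0 set to 0, which is the zero-filled fallback you asked for.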
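If the ragged output from the question is needed after all, one option is to give up a single array and return a list with one array per batch. The sketch below uses NumPy and the simplified example from the question; the helper names `select_ragged` and `select_padded` are made up for illustration. The second helper is an assumption about what "fill in the missing dimensions with 0" could mean: kept rows are moved to the front of each batch and the remainder is zero-padded up to the largest kept count.

```python
import numpy as np

def select_ragged(pred, mask):
    """Return a list with one array per batch, keeping only rows where mask == 1."""
    keep = mask.astype(bool)
    return [p[k] for p, k in zip(pred, keep)]

def select_padded(pred, mask):
    """Move kept rows to the front of each batch and zero-fill the rest."""
    keep = mask.astype(bool)
    # A stable sort on ~keep puts kept positions first and preserves their order.
    order = np.argsort(~keep, axis=1, kind="stable")
    gathered = np.take_along_axis(pred, order[:, :, None], axis=1)
    sorted_keep = np.take_along_axis(keep, order, axis=1)
    # Zero out the rows that were not selected by the mask.
    gathered = gathered * sorted_keep[:, :, None]
    # Trim to the largest number of kept rows in any batch.
    max_kept = int(keep.sum(axis=1).max())
    return gathered[:, :max_kept]

mask = np.array([[1, 0, 0],
                 [1, 1, 0],
                 [0, 0, 1]])

pred = np.array([[[0.12, 0.23, 0.45, 0.56, 0.57],
                  [0.91, 0.98, 0.97, 0.96, 0.95],
                  [0.24, 0.46, 0.68, 0.80, 0.15]],
                 [[1.12, 1.23, 1.45, 1.56, 1.57],
                  [1.91, 1.98, 1.97, 1.96, 1.95],
                  [1.24, 1.46, 1.68, 1.80, 1.15]],
                 [[2.12, 2.23, 2.45, 2.56, 2.57],
                  [2.91, 2.98, 2.97, 2.96, 2.95],
                  [2.24, 2.46, 2.68, 2.80, 2.15]]])

ragged = select_ragged(pred, mask)   # list of arrays with 1, 2 and 1 rows
padded = select_padded(pred, mask)   # single array of shape (3, 2, 5)
```

The list version loops over the batch in Python, which is usually fine for 64 batches. The padded version stays vectorized, at the cost of carrying zero rows for batches with fewer selected entries.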