When I have a tensor m of shape [12, 10] and a vector s of scalars with shape [12], how can I multiply each row of m by the corresponding scalar in s?
You need to add a corresponding singleton dimension:
m * s[:, None]
s[:, None] has shape (12, 1). When multiplying a (12, 10) tensor by a (12, 1) tensor, PyTorch broadcasts s along the second (singleton) dimension and performs the element-wise product, so every entry in row i of m is multiplied by s[i].
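As a quick sanity check of the broadcasting behaviour described above, the following sketch compares the broadcast product with an explicit row-by-row loop. The concrete values and the names scaled and expected are only illustrative assumptions, not part of the original answer.

```python
import torch

# Small deterministic inputs with the shapes from the question.
m = torch.arange(120, dtype=torch.float32).reshape(12, 10)
s = torch.arange(12, dtype=torch.float32)

# s[:, None] has shape (12, 1) and is broadcast across the 10 columns of m.
scaled = m * s[:, None]

# Reference result: scale each row explicitly, then stack the rows back together.
expected = torch.stack([m[i] * s[i] for i in range(m.shape[0])])

print(scaled.shape)
print(torch.equal(scaled, expected))
```

If the two results agree, the singleton dimension is doing exactly what the loop does, just without the Python-level iteration.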
You can broadcast a vector to a higher-dimensional tensor like so:
def row_mult(input, vector):
    # One size-1 dimension for every dimension of input after the first.
    extra_dims = (1,) * (input.dim() - 1)
    return input * vector.view(-1, *extra_dims)
A technique that is slightly hard to understand at first, but very powerful, is Einstein summation:
torch.einsum('i,ij->ij', s, m)
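To make the subscripts less opaque: in 'i,ij->ij' the index i labels the entries of s and the rows of m, while j labels the columns of m. Both indices appear in the output, so nothing is summed and the result is just s[i] * m[i, j]. A minimal sketch comparing it with the broadcasting form (random inputs and variable names are assumptions for illustration):

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable
m = torch.rand(12, 10)
s = torch.rand(12)

# No index is dropped on the right-hand side, so einsum performs no summation here.
via_einsum = torch.einsum('i,ij->ij', s, m)

# The same product written with an explicit singleton dimension.
via_broadcast = m * s[:, None]

print(torch.allclose(via_einsum, via_broadcast))
```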
Shai's answer works if you know the number of dimensions in advance and can hardcode the correct number of None's. This can be extended to extra dimensions if required:
mask = (torch.rand(12) > 0.5).int()
data = (torch.rand(12, 2, 3, 4))
result = data * mask[:,None,None,None]
result.shape # torch.Size([12, 2, 3, 4])
mask[:,None,None,None].shape # torch.Size([12, 1, 1, 1])
If you are dealing with data of variable or unknown dimensions, you may need to manually extend mask to the correct shape:
mask = (torch.rand(12) > 0.5).int()
while mask.dim() < data.dim(): mask.unsqueeze_(1)
result = data * mask
result.shape # torch.Size([12, 2, 3, 4])
mask.shape # torch.Size([12, 1, 1, 1])
This is a bit of an ugly solution, but it does work. There is probably a much more elegant way to correctly reshape the mask tensor inline for a variable number of dimensions.
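One possible inline alternative, offered only as a sketch, is to build the target shape from data.dim() and reshape once, which is the same idea as the view-based helper earlier in the thread. The helper name expand_trailing is hypothetical. Unlike the unsqueeze_ loop, this does not modify mask in place.

```python
import torch

def expand_trailing(vector, target):
    """Reshape a 1-D vector to (N, 1, ..., 1) so it broadcasts against target.

    Hypothetical helper: assumes vector has as many elements as target has
    entries along its first dimension.
    """
    return vector.reshape(-1, *([1] * (target.dim() - 1)))

data = torch.rand(12, 2, 3, 4)
mask = (torch.rand(12) > 0.5).int()

# Works for any number of trailing dimensions in data.
result = data * expand_trailing(mask, data)

print(result.shape)                       # same shape as data
print(expand_trailing(mask, data).shape)  # one leading dim, then size-1 dims
print(mask.shape)                         # original mask is left unchanged
```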