
PyTorch tensor broadcasting along an axis

import torch

a = torch.rand(20)
b = torch.rand(20, 20)
a + b  # Works!

a = torch.rand(32, 20)
b = torch.rand(32, 20, 20)
a + b  # Doesn't work! Raises a RuntimeError (shapes are not broadcastable)

Does anyone know how the broadcasting in the first example could be generalized to the second example, along axis 0, without for loops?

I tried plain addition, but broadcasting in PyTorch doesn't seem to work this way!

The dimensions in the second case are incompatible. You need to insert a size-1 (singleton) dimension into a to get the same behavior as in the first case: a.unsqueeze(1) + b.
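To make "insert a size-1 dimension" concrete, here is a small plain-Python sketch of what unsqueeze does to a shape. It works on shape tuples only, and `unsqueeze_shape` is a hypothetical helper (not a PyTorch function); unlike the real Tensor.unsqueeze, it only accepts a non-negative `dim`.

```python
def unsqueeze_shape(shape, dim):
    """Shape after inserting a size-1 axis at position `dim`.

    Hypothetical helper mirroring the effect of Tensor.unsqueeze(dim) on
    the shape; only non-negative `dim` is handled in this sketch.
    """
    if not 0 <= dim <= len(shape):
        raise ValueError(f"dim must be in [0, {len(shape)}], got {dim}")
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


print(unsqueeze_shape((32, 20), 1))  # (32, 1, 20)
print(unsqueeze_shape((32, 20), 0))  # (1, 32, 20)
```

With dim=1 the result lines up with b's 32 x 20 x 20, whereas dim=0 reproduces the 1 x 32 x 20 layout that implicit broadcasting already tries, which is the one that fails.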


PyTorch follows the same broadcasting rules as NumPy. See https://numpy.org/doc/stable/user/basics.broadcasting.html

In particular, see the first paragraph of the General Broadcasting Rules section:

When operating on two arrays, NumPy compares their shapes element-wise. It starts with the trailing (i.e. rightmost) dimensions and works its way left. Two dimensions are compatible when

  1. they are equal, or

  2. one of them is 1

Furthermore

Arrays do not need to have the same number of dimensions. For example, if you have a 256x256x3 array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values.

This effectively says that if we line up the shapes starting from the right and insert ones in any blank spots, all the dimensions must be compatible.
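The rule can be sketched in plain Python as a shape calculator. This is a minimal illustration for two operands that works on shape tuples only (no tensors); `broadcast_shape` is a hypothetical helper name. PyTorch and NumPy provide their own versions as `torch.broadcast_shapes` and `numpy.broadcast_shapes`.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Hypothetical helper following the rule quoted above: align the shapes
    on the right, treat missing leading dimensions as 1, and require each
    pair of sizes to be equal or to contain a 1.
    """
    result = []
    # Walk from the trailing (rightmost) dimension leftwards; fillvalue=1
    # plays the role of the ones inserted into the blank spots.
    for x, y in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if x == y or x == 1 or y == 1:
            # If x is 1 the output takes y's size; otherwise x == y or y == 1.
            result.append(y if x == 1 else x)
        else:
            raise ValueError(
                f"incompatible sizes {x} and {y} in shapes {shape_a} and {shape_b}"
            )
    return tuple(reversed(result))


print(broadcast_shape((20,), (20, 20)))             # (20, 20)
print(broadcast_shape((32, 1, 20), (32, 20, 20)))   # (32, 20, 20)
```

Calling `broadcast_shape((32, 20), (32, 20, 20))` raises ValueError, matching the failing case in the question.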

Consider the first case. If we line up the tensor shapes starting from the right, we have

a:       20
b:  20 x 20

and then insert a 1 into the missing spot:

a:  1 x 20
b: 20 x 20

we see that the shapes are compatible: the first dimension contains a 1, and both sizes in the second dimension are equal. The output shape of the broadcast operation is 20 x 20, taking the first 20 from the first dimension of b.

Now consider the second case. If we try to do the same, we get

a:      32 x 20
b: 32 x 20 x 20

and after inserting a 1 into the missing spot we have

a:  1 x 32 x 20
b: 32 x 20 x 20

These shapes are incompatible: the second dimension of a is 32 while the second dimension of b is 20, and since 32 != 20 and neither of them is 1, the rule is violated.
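To see exactly which axis breaks the rule, a small variation can report every conflicting position instead of stopping at the first one. Again this is plain Python on shape tuples, and `conflicting_axes` is a hypothetical name; axes are numbered in the left-padded layout, so axis 1 is the "second dimension" discussed here.

```python
def conflicting_axes(shape_a, shape_b):
    """List (axis, size_a, size_b) for every axis that cannot broadcast.

    Hypothetical helper: both shapes are left-padded with 1s to the same
    length (the "insert a 1 into the missing spot" step), then compared
    axis by axis.
    """
    ndim = max(len(shape_a), len(shape_b))
    padded_a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    return [
        (axis, x, y)
        for axis, (x, y) in enumerate(zip(padded_a, padded_b))
        if x != y and x != 1 and y != 1
    ]


print(conflicting_axes((20,), (20, 20)))             # []
print(conflicting_axes((32, 20), (32, 20, 20)))      # [(1, 32, 20)]
print(conflicting_axes((32, 1, 20), (32, 20, 20)))   # []
```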


For the second example, one way to make the shapes compatible is to reshape a to 32 x 1 x 20, i.e. insert an explicit size-1 dimension in the middle. This can be done with any of these three methods:

a.reshape(32, 1, 20)+b

or equivalently

a.unsqueeze(1)+b

or equivalently

a[:, None, :]+b
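What the broadcast sum computes can be spelled out with the explicit loops the question wanted to avoid. The sketch below is a plain-Python reference using nested lists instead of tensors and tiny made-up sizes (N=2, J=3, K=2), for a of shape N x K and b of shape N x J x K; `add_along_axis0` is a hypothetical name used only for illustration.

```python
def add_along_axis0(a, b):
    """Loop equivalent of a.unsqueeze(1) + b for nested lists.

    a has shape (N, K) and b has shape (N, J, K); the result has shape
    (N, J, K) with out[i][j][k] == a[i][k] + b[i][j][k].
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same size along axis 0")
    out = []
    for row_a, mat_b in zip(a, b):  # loop over axis 0 (N)
        # The size-1 axis inserted by unsqueeze(1) means row_a is reused
        # unchanged for every row_b along axis 1 (J).
        out.append([[x + y for x, y in zip(row_a, row_b)] for row_b in mat_b])
    return out


a = [[1, 2], [10, 20]]                      # shape (2, 2)
b = [[[0, 0], [1, 1], [2, 2]],
     [[0, 0], [1, 1], [2, 2]]]              # shape (2, 3, 2)
print(add_along_axis0(a, b))
# [[[1, 2], [2, 3], [3, 4]], [[10, 20], [11, 21], [12, 22]]]
```

In PyTorch, each of the three one-liners above produces this result directly, with no Python-level loops.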


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM