I have a tensor with shape torch.Size([3, 224, 225]). When I do tensor.mean([1,2]) I get tensor([0.6893, 0.5840, 0.4741]). What does [1,2] mean here?
Operations that aggregate along dimensions, like min, max, mean, sum, etc., specify the dimension along which to aggregate. It is common to use these operations across every dimension (i.e. get the mean for the entire tensor) or across a single dimension (i.e. tensor.mean(dim=2) or tensor.mean(2) returns the mean of the 225 elements for each of the 3 x 224 vectors).
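As a rough illustration of the two common cases above, here is a small sketch. The tensor `x` is random stand-in data with the same shape as in the question, not the asker's actual tensor.

```python
import torch

# Random stand-in with the same shape as the tensor in the question.
x = torch.rand(3, 224, 225)

overall = x.mean()       # aggregate over every dimension -> 0-dim tensor
last = x.mean(dim=2)     # aggregate over dimension 2 only
last_pos = x.mean(2)     # same call, dimension passed positionally

print(overall.shape)     # torch.Size([])
print(last.shape)        # torch.Size([3, 224])
```

The aggregated dimension disappears from the result, which is why reducing dimension 2 leaves a 3 x 224 tensor.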
PyTorch also allows these operations across a set of multiple dimensions, such as in your case. This means taking the mean of the 224 x 225 elements for each index along the 0th (non-aggregated) dimension. Likewise, if your original tensor shape was a.shape = torch.Size([3,224,10,225]), a.mean([1,3]) would return a tensor of shape [3,10].
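The shapes described here can be checked with a short sketch. The tensors `t` and `a` hold random stand-in data, and the `keepdim=True` call is an extra illustration that is not part of the original answer.

```python
import torch

# Random stand-ins with the shapes discussed in the answer.
t = torch.rand(3, 224, 225)
a = torch.rand(3, 224, 10, 225)

per_channel = t.mean([1, 2])         # aggregate over dims 1 and 2 together
reduced = a.mean([1, 3])             # aggregate over dims 1 and 3, keep dims 0 and 2
kept = t.mean([1, 2], keepdim=True)  # aggregated dims are kept with size 1

print(per_channel.shape)  # torch.Size([3])
print(reduced.shape)      # torch.Size([3, 10])
print(kept.shape)         # torch.Size([3, 1, 1])
```

Passing `keepdim=True` can be handy when the result needs to broadcast against the original tensor, for example when subtracting a per-channel mean.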
The shape of your tensor is 3 across dimension 0, 224 across dimension 1 and 225 across dimension 2.
I would say that tensor.mean([1,2]) calculates the mean across dimension 1 as well as dimension 2. That's why you are getting 3 values. Each plane spanned by dimensions 1 and 2, of size 224x225, is reduced to a single scalar value. Since there are 3 such planes, you get 3 values back. Each value represents the mean of a whole plane with 224x225 values.
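One way to convince yourself of this is to compute the mean of each plane separately and compare. This is only a sketch on random stand-in data (the variable names are made up for illustration):

```python
import torch

# Random stand-in with the same shape as the tensor in the question.
t = torch.rand(3, 224, 225)

multi = t.mean([1, 2])  # one value per plane along dimension 0

# Mean of each 224x225 plane, computed one plane at a time.
per_plane = torch.stack([t[i].mean() for i in range(t.shape[0])])

# Equivalent view: flatten each plane to 224*225 values, then reduce once.
flattened = t.reshape(3, -1).mean(1)

print(torch.allclose(multi, per_plane))  # True
print(torch.allclose(multi, flattened))  # True
```

`torch.allclose` is used instead of exact equality because the different reduction orders can differ by floating-point rounding.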