
What does [1,2] mean in .mean([1,2]) for a tensor?

I have a tensor with shape torch.Size([3, 224, 225]). When I do tensor.mean([1,2]) I get tensor([0.6893, 0.5840, 0.4741]). What does [1,2] mean here?

Operations that aggregate along dimensions, such as min, max, mean, and sum, take an argument specifying the dimension along which to aggregate. It is common to apply these operations across every dimension (i.e. get the mean of the entire tensor) or along a single dimension (i.e. tensor.mean(dim=2) or tensor.mean(2), which returns the mean of the 225 elements for each of the 3 x 224 vectors).
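To make the two common cases concrete, here is a minimal sketch. It assumes a random tensor named x with the same shape as in the question; the variable names are only for illustration.

```python
import torch

# Stand-in for the tensor in the question (random values, same shape).
x = torch.rand(3, 224, 225)

# No dim argument: reduce over every dimension, leaving a 0-d tensor.
overall_mean = x.mean()
print(overall_mean.shape)   # torch.Size([])

# dim=2: average the 225 elements of each of the 3 x 224 vectors.
row_means = x.mean(dim=2)
print(row_means.shape)      # torch.Size([3, 224])
```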

PyTorch also allows these operations across a set of multiple dimensions, as in your case. This means taking the mean of the 224 x 225 elements for each index along the 0th (non-aggregated) dimension. Likewise, if your original tensor had shape a.shape = torch.Size([3,224,10,225]), a.mean([1,3]) would return a tensor of shape [3,10].
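The four-dimensional example can be checked with a short sketch. The tensor a is filled with random values here (an assumption for illustration), and the step-by-step variant is just one way to confirm the result.

```python
import torch

# Hypothetical 4-d tensor with the shape used in the example above.
a = torch.rand(3, 224, 10, 225)

# Reduce dimensions 1 and 3 in one call; dimensions 0 and 2 remain.
out = a.mean([1, 3])
print(out.shape)            # torch.Size([3, 10])

# Same result one dimension at a time. Reducing the last dimension
# first keeps the index of dimension 1 unchanged.
stepwise = a.mean(dim=3).mean(dim=1)
print(torch.allclose(out, stepwise, atol=1e-6))

# keepdim=True keeps the reduced dimensions with size 1.
kept = a.mean([1, 3], keepdim=True)
print(kept.shape)           # torch.Size([3, 1, 10, 1])
```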

The shape of your tensor is 3 across dimension 0, 224 across dimension 1 and 225 across dimension 2.

I would say that tensor.mean([1,2]) calculates the mean across dimension 1 as well as dimension 2. That's why you are getting 3 values. Each plane spanned by dimensions 1 and 2, of size 224x225, is reduced to a single scalar. Since there are 3 such planes, you get 3 values back, each representing the mean of a whole plane of 224x225 values.
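One way to convince yourself of this is to compute the mean of each plane separately and compare. This is a minimal sketch using a random tensor t as a stand-in for yours, so the printed numbers will differ from the ones in the question.

```python
import torch

# Stand-in for the tensor in the question (random values, same shape).
t = torch.rand(3, 224, 225)

# Mean over dimensions 1 and 2: one value per plane.
plane_means = t.mean([1, 2])
print(plane_means.shape)    # torch.Size([3])

# The same thing done by hand: take each 224x225 plane and average it.
manual = torch.stack([t[i].mean() for i in range(t.shape[0])])

# Or flatten each plane into a single row and average along it.
flattened = t.reshape(t.shape[0], -1).mean(dim=1)

print(torch.allclose(plane_means, manual, atol=1e-6))
print(torch.allclose(plane_means, flattened, atol=1e-6))
```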
