
Torch sum subsets of tensor

If the tensor is of shape [20, 5], then I need to take 10 rows at a time and sum them, so the result is [2, 5].

eg:
shape[20,5] -> shape[2, 5] (sum 10 at a time)
shape[100, 20] -> shape[10,20] (sum 10 at a time)

Is there any faster/optimal way to do this?

eg:
for [[1, 1], [1, 2], [3, 4], [1, 2]] I want [[2, 3], [4, 6]] by taking the sum of every 2 rows.

I am not aware of any off-the-shelf solution for that.

If having the average is enough, you can use nn.AvgPool1d (https://pytorch.org/docs/stable/generated/torch.nn.AvgPool1d.html#avgpool1d):

import torch
import torch.nn as nn

batch_size, channels, length = 4, 5, 100  # example values; length must be a multiple of 10
x = torch.rand(batch_size, channels, length)
pool = nn.AvgPool1d(kernel_size=10, stride=10)

avg = pool(x)  # shape: (batch_size, channels, length // 10)

With this solution, just make sure you are averaging over the correct dimension (AvgPool1d pools over the last dimension of its input).

EDIT: I just realized you can get the sum instead by modifying the last line to avg = pool(x) * kernel_size, since the average times the window size is the sum!
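For the shapes in the question, a minimal sketch of the same idea (my own example; it assumes the rows to be summed lie along dim 0, so the tensor is transposed into the layout that AvgPool1d pools over):

import torch
import torch.nn as nn

x = torch.arange(100, dtype=torch.float32).reshape(20, 5)  # example [20, 5] input
pool = nn.AvgPool1d(kernel_size=10, stride=10)

# AvgPool1d pools over the last dimension, so move the row dimension last,
# pool, multiply the averages back into sums, and transpose back.
sums = (pool(x.T.unsqueeze(0)) * 10).squeeze(0).T
sums.shape  #=> torch.Size([2, 5])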

You can also just write your own function that does the summing for you:

import torch

def SumWindow(x, window_size, dim):
    # Split into chunks of `window_size` along `dim`, sum each chunk, and re-concatenate.
    input_split = torch.split(x, window_size, dim)
    # keepdim=True keeps `dim` around so the per-chunk sums can be concatenated along it
    # (may be expensive if there are too many chunks)
    input_sum = [v.sum(dim=dim, keepdim=True) for v in input_split]
    out = torch.cat(input_sum, dim=dim)
    return out
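
Continuing from the function above, a quick check on the question's [20, 5] case (my own example input):

x = torch.ones(20, 5)
out = SumWindow(x, window_size=10, dim=0)
out.shape  #=> torch.Size([2, 5]); every entry is 10.0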

The question is not completely clear to me, but I cannot post this as a comment, so here it is as an answer.

For the first case you have:

t1 = torch.tensor([[1., 1.], [1., 2.], [3., 4.], [1.,2.]])
t1.shape #=> torch.Size([4, 2])
t1
tensor([[1., 1.],
        [1., 2.],
        [3., 4.],
        [1., 2.]])

To get the desired output you should reshape:

tr1 = t1.reshape([2, 2, 2])
res1 = torch.sum(tr1, axis = 1)
res1.shape #=> torch.Size([2, 2])
res1
tensor([[2., 3.],
        [4., 6.]])

Let's take a tensor of all ones (torch.ones) for the second case.

t2 = torch.ones((20, 5))
t2.shape #=> torch.Size([20, 5])
t2
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

So, reshaping the same way to get the required (?) result:

tr2 = t2.reshape((2, 10, 5))
res2 = torch.sum(tr2, axis = 1)
res2.shape #=> torch.Size([2, 5])
res2
tensor([[10., 10., 10., 10., 10.],
        [10., 10., 10., 10., 10.]])

Is this what you are looking for?
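
More generally, both cases boil down to the same reshape-and-sum pattern. Here is a minimal wrapper (my own helper, assuming the number of rows is an exact multiple of the window size):

import torch

def sum_rows(x, window):
    # Group `window` consecutive rows together, then sum within each group.
    return x.reshape(-1, window, x.shape[1]).sum(dim=1)

sum_rows(torch.ones(20, 5), 10).shape    #=> torch.Size([2, 5])
sum_rows(torch.ones(100, 20), 10).shape  #=> torch.Size([10, 20])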
