
Speeding up a PyTorch tensor operation

I am trying to speed up the operation below, ideally with some kind of matrix/vector multiplication. Can anyone see a nice, quick solution? It should also work for the special case where theta is a 0-dimensional tensor (shape torch.Size([])), but I am not able to initialize such a tensor. See the image below for the type of tensor I am referring to:

[Image: tensor to add to test]

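import torch

def adstock_geometric(x: torch.Tensor, theta: float):
    # Use a float buffer so theta * x_decayed[xi - 1] is not truncated when x is an integer tensor
    x_decayed = torch.zeros_like(x, dtype=torch.float)
    x_decayed[0] = x[0]

    # Geometric adstock recurrence: each step carries over theta times the previous decayed value
    for xi in range(1, len(x_decayed)):
        x_decayed[xi] = x[xi] + theta * x_decayed[xi - 1]

    return x_decayed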

def adstock_multiple_samples(x: torch.Tensor, theta: torch.Tensor):
    # A 0-dim theta (torch.Size([])) holds a single decay rate
    if theta.dim() == 0:
        return adstock_geometric(x=x, theta=theta.item())
    # One decayed copy of x per theta value: shape (len(theta), *x.shape)
    x_decayed = torch.zeros((len(theta), *x.shape))
    for idx, theta_ in enumerate(theta.tolist()):
        x_decayed[idx] = adstock_geometric(x=x, theta=theta_)
    return x_decayed

if __name__ == '__main__':
    ones = torch.tensor([1])
    hundreds = torch.tensor([idx for idx in range(100)])
    x = torch.tensor([[idx] for idx in range(112)])
    ones = adstock_multiple_samples(x=x,
                                    theta=ones)
    hundreds = adstock_multiple_samples(x=x,
                                        theta=hundreds)
    print(ones)
    print(hundreds)

I came up with the following, which is 40 times faster on your example:

import torch

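def adstock_multiple_samples(x: torch.Tensor, theta: torch.Tensor):
    # powers[i, j] = i - j, clipped at 0 above the diagonal
    arange = torch.arange(len(x))
    powers = (arange[:, None] - arange).clip(0)
    # For each theta: lower-triangular decay matrix M[i, j] = theta ** (i - j),
    # multiplied by x[j] and summed over j -> shape (len(theta), len(x))
    return ((theta[:, None, None] ** powers[None, :, :]).tril() * x).sum(-1)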
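
Why this matrix trick works: unrolling the recurrence from the question gives a closed form that is a product with a lower-triangular matrix, where T = len(x):

```latex
\begin{aligned}
y_0 &= x_0, \qquad y_i = x_i + \theta\, y_{i-1} \quad (1 \le i < T) \\
\Rightarrow\; y_i &= \sum_{j=0}^{i} \theta^{\,i-j} x_j
  = \sum_{j=0}^{T-1} M_{ij}\, x_j,
\qquad
M_{ij} = \begin{cases} \theta^{\,i-j} & j \le i \\ 0 & j > i \end{cases}
\end{aligned}
```

In the code, powers holds i - j (clipped at 0 so no negative exponents are computed for the entries that .tril() discards anyway), theta ** powers gives the entries of M, .tril() zeroes the entries with j > i, and * x followed by .sum(-1) performs the product M x once per value of theta. The trade-off is memory: the intermediate tensor has len(theta) * len(x) ** 2 elements, whereas the loop only stores the output.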

It behaves as expected:

>>> x = torch.arange(112)
>>> theta = torch.arange(100)
>>> adstock_multiple_samples(x, theta)
... # the same output
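
A minimal sketch of how the "same output" claim could be checked. It assumes float64 inputs and decay rates between 0 and 0.9 (an assumption: with the integer thetas up to 99 from the question, terms such as 99 ** 111 overflow, so comparing outputs on that data is not meaningful). adstock_loop and adstock_vectorized are hypothetical names for the question's recurrence and the answer's code:

```python
import torch

def adstock_loop(x: torch.Tensor, theta: float) -> torch.Tensor:
    # Reference recurrence from the question: y[i] = x[i] + theta * y[i - 1]
    y = torch.zeros_like(x)
    y[0] = x[0]
    for i in range(1, len(x)):
        y[i] = x[i] + theta * y[i - 1]
    return y

def adstock_vectorized(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Same construction as the answer: lower-triangular matrix of theta ** (i - j)
    arange = torch.arange(len(x))
    powers = (arange[:, None] - arange).clip(0)
    return ((theta[:, None, None] ** powers[None, :, :]).tril() * x).sum(-1)

x = torch.arange(112, dtype=torch.float64)
theta = torch.linspace(0.0, 0.9, 10, dtype=torch.float64)

fast = adstock_vectorized(x, theta)  # shape (10, 112)
slow = torch.stack([adstock_loop(x, t) for t in theta.tolist()])
print(torch.allclose(fast, slow))  # True
```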

Note that I assumed x is a 1D tensor, since the second dimension in your example was not needed.
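
If x keeps the question's column shape (112, 1), it has to be squeezed first: passed as-is, it would broadcast against the row index i of the decay matrix instead of the column index j, and the result would be silently wrong rather than raising an error. A sketch, assuming float32 decay rates between 0 and 0.9 for illustration:

```python
import torch

def adstock_multiple_samples(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Vectorized version from the answer; expects a 1D x of length T
    arange = torch.arange(len(x))
    powers = (arange[:, None] - arange).clip(0)
    return ((theta[:, None, None] ** powers[None, :, :]).tril() * x).sum(-1)

# Column-shaped input as in the question: shape (112, 1)
x_col = torch.arange(112, dtype=torch.float32).reshape(112, 1)
theta = torch.linspace(0.0, 0.9, 100)

# Drop the trailing singleton dimension, run the 1D version,
# then add it back to match the question's (100, 112, 1) layout
out = adstock_multiple_samples(x_col.squeeze(-1), theta).unsqueeze(-1)
print(out.shape)  # torch.Size([100, 112, 1])
```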

It also works with theta = torch.empty((0,)), in which case it returns an empty tensor of shape (0, len(x)).
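
Note that torch.empty((0,)) has shape torch.Size([0]), which is not the torch.Size([]) case from the question. A 0-dimensional tensor can be created with e.g. torch.tensor(0.5); passed straight to the function above, it fails with an IndexError at theta[:, None, None], because a 0-dim tensor has no dimension to slice. A possible sketch, using a hypothetical wrapper adstock_any_theta that promotes theta with torch.atleast_1d and, as an assumed design choice, returns a single 1D series for a scalar theta (like the question's adstock_geometric):

```python
import torch

def adstock_any_theta(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper around the answer's code: promote a 0-dim theta
    # (shape torch.Size([])) to shape (1,) so that theta[:, None, None] works
    scalar_theta = theta.dim() == 0
    theta = torch.atleast_1d(theta)
    arange = torch.arange(len(x))
    powers = (arange[:, None] - arange).clip(0)
    out = ((theta[:, None, None] ** powers[None, :, :]).tril() * x).sum(-1)
    # For a scalar theta, drop the sample dimension again
    return out[0] if scalar_theta else out

x = torch.arange(5, dtype=torch.float32)

scalar = torch.tensor(0.5)  # a 0-dim tensor
print(scalar.shape)  # torch.Size([])

y = adstock_any_theta(x, scalar)  # values 0, 1, 2.5, 4.25, 6.125
empty = adstock_any_theta(x, torch.empty((0,)))
print(empty.shape)  # torch.Size([0, 5])
```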


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM