[英]Speeding up a pytorch tensor operation
I am trying to speed up the below operation by doing some sort of matrix/vector-multiplication, can anyone see a nice quick solution?我试图通过做某种矩阵/向量乘法来加速下面的操作,有人能看到一个很好的快速解决方案吗? It should also work for a special case where a tensor has shape 0 (torch.Size([])) but i am not able to initialize such a tensor.
它也应该适用于张量的形状为 0 (torch.Size([])) 但我无法初始化这样的张量的特殊情况。 See the image below for the type of tensor i am referring to: tensor to add to test
有关我所指的张量类型,请参见下图:要添加到测试的张量
def adstock_geometric(x: torch.Tensor, theta: float):
x_decayed = torch.zeros_like(x)
x_decayed[0] = x[0]
for xi in range(1, len(x_decayed)):
x_decayed[xi] = x[xi] + theta * x_decayed[xi - 1]
return x_decayed
def adstock_multiple_samples(x: torch.Tensor, theta: torch.Tensor):
listtheta = theta.tolist()
if isinstance(listtheta, float):
return adstock_geometric(x=x,
theta=theta)
x_decayed = torch.zeros((100, 112, 1))
for idx, theta_ in enumerate(listtheta):
x_decayed_one_entry = adstock_geometric(x=x,
theta=theta_)
x_decayed[idx] = x_decayed_one_entry
return x_decayed
if __name__ == '__main__':
ones = torch.tensor([1])
hundreds = torch.tensor([idx for idx in range(100)])
x = torch.tensor([[idx] for idx in range(112)])
ones = adstock_multiple_samples(x=x,
theta=ones)
hundreds = adstock_multiple_samples(x=x,
theta=hundreds)
print(ones)
print(hundreds)
I came up with the following, which is 40 times faster on your example:我想出了以下内容,在您的示例中速度提高了 40 倍:
import torch
def adstock_multiple_samples(x: torch.Tensor, theta: torch.Tensor):
arange = torch.arange(len(x))
powers = (arange[:, None] - arange).clip(0)
return ((theta[:, None, None] ** powers[None, :, :]).tril() * x).sum(-1)
It behaves as expected:它的行为符合预期:
>>> x = torch.arange(112)
>>> theta = torch.arange(100)
>>> adstock_multiple_samples(x, theta)
... # the same output
Note that I considered that x
was a 1D-tensor, as for your example the second dimension was not needed.请注意,我认为
x
是一维张量,因为对于您的示例,不需要第二维。
It also works with theta = torch.empty((0,))
, and it returns an empty tensor.它也适用于
theta = torch.empty((0,))
,它返回一个空张量。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.