[英]Is there any pytorch function can combine the specific continuous dimensions of tensor into one?
Let's call the function I'm looking for " magic_combine
", which can combine the continuous dimensions of tensor I give to it.让我们将我正在寻找的函数称为“
magic_combine
”,它可以组合我给它的张量的连续维度。 For more specific, I want it to do the following thing:更具体地说,我希望它做以下事情:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
I know that torch.view()
can do the similar thing.我知道
torch.view()
可以做类似的事情。 But I'm just wondering if there is any more elegant way to achieve the goal?但我只是想知道是否有更优雅的方式来实现目标?
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])
Seems to me a bit simpler than the current accepted answer and doesn't go through a list
constructor (3 times).在我看来,比当前接受的答案简单一点,并且没有经过
list
构造函数(3 次)。
I am not sure what you have in mind with "a more elegant way", but Tensor.view()
has the advantage not to re-allocate data for the view (original tensor and view share the same data), making this operation quite light-weight.我不确定你对“更优雅的方式”有什么想法,但
Tensor.view()
的优点是不为视图重新分配数据(原始张量和视图共享相同的数据),使得这个操作相当轻的。
As mentioned by @UmangGupta, it is however rather straight-forward to wrap this function to achieve what you want, eg:正如@UmangGupta 所提到的,包装这个函数来实现你想要的东西是相当直接的,例如:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
Also possible with torch einops .也可以使用torch einops 。
> pip install einops
from einops import rearrange
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')
There is a variant of flatten<\/code> that takes
start_dim<\/code> and
end_dim<\/code> parameters.
有一个
flatten<\/code>变体,它采用
start_dim<\/code>和
end_dim<\/code>参数。
You can call it in the same way as your
magic_combine<\/code> (except that
end_dim<\/code> is inclusive).
您可以使用与您的
magic_combine<\/code>相同的方式调用它(除了
end_dim<\/code>包含在内)。
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.flatten(2, 4) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.