简体   繁体   English

是否有任何pytorch函数可以将张量的特定连续维度合二为一?

[英]Is there any pytorch function can combine the specific continuous dimensions of tensor into one?

Let's call the function I'm looking for " magic_combine ", which can combine the continuous dimensions of tensor I give to it.让我们将我正在寻找的函数称为“ magic_combine ”,它可以组合我给它的张量的连续维度。 For more specific, I want it to do the following thing:更具体地说,我希望它做以下事情:

a = torch.zeros(1, 2, 3, 4, 5, 6)  
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4 
print(b.size()) # should be (1, 2, 60, 6)

I know that torch.view() can do the similar thing.我知道torch.view()可以做类似的事情。 But I'm just wondering if there is any more elegant way to achieve the goal?但我只是想知道是否有更优雅的方式来实现目标?

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])

Seems to me a bit simpler than the current accepted answer and doesn't go through a list constructor (3 times).在我看来,比当前接受的答案简单一点,并且没有经过list构造函数(3 次)。

I am not sure what you have in mind with "a more elegant way", but Tensor.view() has the advantage not to re-allocate data for the view (original tensor and view share the same data), making this operation quite light-weight.我不确定你对“更优雅的方式”有什么想法,但Tensor.view()的优点是不为视图重新分配数据(原始张量和视图共享相同的数据),使得这个操作相当轻的。

As mentioned by @UmangGupta, it is however rather straight-forward to wrap this function to achieve what you want, eg:正如@UmangGupta 所提到的,包装这个函数来实现你想要的东西是相当直接的,例如:

import torch

def magic_combine(x, dim_begin, dim_end):
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])

Also possible with torch einops .也可以使用torch einops

Github . GitHub

> pip install einops
from einops import rearrange

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')

There is a variant of flatten<\/code> that takes start_dim<\/code> and end_dim<\/code> parameters.有一个flatten<\/code>变体,它采用start_dim<\/code>和end_dim<\/code>参数。 You can call it in the same way as your magic_combine<\/code> (except that end_dim<\/code> is inclusive).您可以使用与您的magic_combine<\/code>相同的方式调用它(除了end_dim<\/code>包含在内)。

a = torch.zeros(1, 2, 3, 4, 5, 6)  
b = a.flatten(2, 4) # combine dimension 2, 3, 4 
print(b.size()) # should be (1, 2, 60, 6)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM