簡體   English   中英

是否有任何pytorch函數可以將張量的特定連續維度合二為一?

[英]Is there any pytorch function can combine the specific continuous dimensions of tensor into one?

讓我們將我正在尋找的函數稱為“ magic_combine ”,它可以組合我給它的張量的連續維度。 更具體地說,我希望它做以下事情:

a = torch.zeros(1, 2, 3, 4, 5, 6)  
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4 
print(b.size()) # should be (1, 2, 60, 6)

我知道torch.view()可以做類似的事情。 但我只是想知道是否有更優雅的方式來實現目標?

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])

在我看來,比當前接受的答案簡單一點,並且沒有經過list構造函數(3 次)。

我不確定你對“更優雅的方式”有什么想法,但Tensor.view()的優點是不為視圖重新分配數據(原始張量和視圖共享相同的數據),使得這個操作相當輕的。

正如@UmangGupta 所提到的,包裝這個函數來實現你想要的東西是相當直接的,例如:

import torch

def magic_combine(x, dim_begin, dim_end):
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])

也可以使用torch einops

GitHub

> pip install einops
from einops import rearrange

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')

有一個flatten<\/code>變體,它采用start_dim<\/code>和end_dim<\/code>參數。 您可以使用與您的magic_combine<\/code>相同的方式調用它(除了end_dim<\/code>包含在內)。

a = torch.zeros(1, 2, 3, 4, 5, 6)  
b = a.flatten(2, 4) # combine dimension 2, 3, 4 
print(b.size()) # should be (1, 2, 60, 6)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM