簡體   English   中英

PyTorch 重塑張量維度

[英]PyTorch reshape tensor dimension

我想將形狀(5,)的向量重塑為形狀(1, 5)的矩陣。

使用 numpy,我可以做到:

>>> import numpy as np
>>> a = np.array([1, 2, 3, 4, 5])
>>> a.shape
(5,)
>>> a = np.reshape(a, (1, 5))
>>> a.shape
(1, 5)
>>> a
array([[1, 2, 3, 4, 5]])

但是我如何使用 PyTorch 做到這一點?

使用torch.unsqueeze(input, dim, out=None)

>>> import torch
>>> a = torch.Tensor([1, 2, 3, 4, 5])
>>> a

 1
 2
 3
 4
 5
[torch.FloatTensor of size 5]

>>> a = a.unsqueeze(0)
>>> a

 1  2  3  4  5
[torch.FloatTensor of size 1x5]

你可能會使用

a.view(1,5)
Out: 

 1  2  3  4  5
[torch.FloatTensor of size 1x5]

有多種方法可以重塑 PyTorch 張量。 您可以將這些方法應用於任何維度的張量。

讓我們從一個二維2 x 3張量開始:

x = torch.Tensor(2, 3)
print(x.shape)
# torch.Size([2, 3])

為了給這個問題增加一些魯棒性,讓我們通過在前面添加一個新維度並在中間添加另一個維度來重塑2 x 3張量,從而產生一個1 x 2 x 1 x 3張量。

方法 1:使用None添加維度

使用 NumPy 樣式None插入(又名np.newaxis )在任何你想要的地方添加維度 這里

print(x.shape)
# torch.Size([2, 3])

y = x[None, :, None, :] # Add new dimensions at positions 0 and 2.
print(y.shape)
# torch.Size([1, 2, 1, 3])

方法2:解壓

使用torch.Tensor.unsqueeze(i) (又名torch.unsqueeze(tensor, i)或就地版本unsqueeze_() )在第 i 個維度添加一個新維度。 返回的張量與原始張量共享相同的數據。 在這個例子中,我們可以使用unqueeze()兩次來添加兩個新維度。

print(x.shape)
# torch.Size([2, 3])

# Use unsqueeze twice.
y = x.unsqueeze(0) # Add new dimension at position 0
print(y.shape)
# torch.Size([1, 2, 3])

y = y.unsqueeze(2) # Add new dimension at position 2
print(y.shape)
# torch.Size([1, 2, 1, 3])

在 PyTorch 的實踐中, 為批處理添加額外的維度可能很重要,因此您可能經常會看到unsqueeze(0)

方法三:查看

使用torch.Tensor.view(*shape)指定所有尺寸。 返回的張量與原始張量共享相同的數據。

print(x.shape)
# torch.Size([2, 3])

y = x.view(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])

方法四:重塑

使用torch.Tensor.reshape(*shape) (又名torch.reshape(tensor, shapetuple) )指定所有維度。 如果原始數據是連續的並且具有相同的步幅,則返回的張量將是輸入的視圖(共享相同的數據),否則將是副本。 此函數類似於 NumPy reshape()函數,因為它允許您定義所有維度並可以返回視圖或副本。

print(x.shape)
# torch.Size([2, 3])

y = x.reshape(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])

此外,作者在 O'Reilly 2019 年出版的Programming PyTorch for Deep Learning中寫道:

現在您可能想知道view()reshape()之間有什么區別。 答案是view()作為原始張量上的視圖運行,因此如果基礎數據發生更改,視圖也會更改(反之亦然)。 但是,如果所需的視圖不連續, view()可能會拋出錯誤; 也就是說,如果從頭開始創建所需形狀的新張量,它不會共享相同的內存塊。 如果發生這種情況,您必須先調用tensor.contiguous()才能使用view() 但是, reshape()會在幕后完成所有這些工作,所以總的來說,我建議使用reshape()而不是view()

方法5:resize_

使用就地函數torch.Tensor.resize_(*sizes)修改原始張量。 該文檔指出:

警告。 這是一種低級方法。 存儲被重新解釋為 C 連續,忽略當前步幅(除非目標大小等於當前大小,在這種情況下張量保持不變)。 在大多數情況下,您將改為使用view()來檢查連續性,或者reshape()來在需要時復制數據。 要使用自定義步幅就地更改大小,請參閱set_()

print(x.shape)
# torch.Size([2, 3])

x.resize_(1, 2, 1, 3)
print(x.shape)
# torch.Size([1, 2, 1, 3])

我的觀察

如果您只想添加一個維度(例如為批次添加第 0 個維度),請使用unsqueeze(0) 如果您想完全改變維度,請使用reshape()

也可以看看:

pytorch中的reshape和view有什么區別?

view() 和 unsqueeze() 有什么區別?

在 PyTorch 0.4 中,是否建議在可能的情況下使用reshape而不是view

對於張量形狀的就地修改,您應該使用tensor.resize_()

In [23]: a = torch.Tensor([1, 2, 3, 4, 5])

In [24]: a.shape
Out[24]: torch.Size([5])


# tensor.resize_((`new_shape`))    
In [25]: a.resize_((1,5))
Out[25]: 

 1  2  3  4  5
[torch.FloatTensor of size 1x5]

In [26]: a.shape
Out[26]: torch.Size([1, 5])

在 PyTorch 中,如果操作末尾有下划​​線(如tensor.resize_() ),則該操作會對原始張量in-place修改。


此外,您可以簡單地在火炬張量中使用np.newaxis來增加維度。 這是一個例子:

In [34]: list_ = range(5)
In [35]: a = torch.Tensor(list_)
In [36]: a.shape
Out[36]: torch.Size([5])

In [37]: new_a = a[np.newaxis, :]
In [38]: new_a.shape
Out[38]: torch.Size([1, 5])

或者你可以使用它,'-1' 意味着你不必指定元素的數量。

In [3]: a.view(1,-1)
Out[3]:

 1  2  3  4  5
[torch.FloatTensor of size 1x5]

這個問題已經得到了徹底的回答,但是我想為經驗不足的 python 開發人員補充一點,您可能會發現*運算符與view()結合使用很有幫助。

例如,如果您有一個特定的張量大小,您希望不同的數據張量符合,您可以嘗試:

img = Variable(tensor.randn(20,30,3)) # tensor with goal shape
flat_size = 20*30*3
X = Variable(tensor.randn(50, flat_size)) # data tensor

X = X.view(-1, *img.size()) # sweet maneuver
print(X.size()) # size is (50, 20, 30, 3)

這也適用於 numpy shape

img = np.random.randn(20,30,3)
flat_size = 20*30*3
X = Variable(tensor.randn(50, flat_size))
X = X.view(-1, *img.shape)
print(X.size()) # size is (50, 20, 30, 3)

torch.reshape()用於欺騙numpy reshape方法。

它出現在view()torch.resize_()之后,它位於dir(torch)包中。

import torch
x=torch.arange(24)
print(x, x.shape)
x_view = x.view(1,2,3,4) # works on is_contiguous() tensor
print(x_view.shape)
x_reshaped = x.reshape(1,2,3,4) # works on any tensor
print(x_reshaped.shape)
x_reshaped2 = torch.reshape(x_reshaped, (-1,)) # part of torch package, while view() and resize_() are not
print(x_reshaped2.shape)

出去:

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23]) torch.Size([24])
torch.Size([1, 2, 3, 4])
torch.Size([1, 2, 3, 4])
torch.Size([24])

但是你知道它也可以作為squeeze()unsqueeze()的替代品嗎?

x = torch.tensor([1, 2, 3, 4])
print(x.shape)
x1 = torch.unsqueeze(x, 0)
print(x1.shape)
x2 = torch.unsqueeze(x1, 1)
print(x2.shape)
x3=x.reshape(1,1,4)
print(x3.shape)
x4=x.reshape(4)
print(x4.shape)
x5=x3.squeeze()
print(x5.shape)

出去:

torch.Size([4])
torch.Size([1, 4])
torch.Size([1, 1, 4])
torch.Size([1, 1, 4])
torch.Size([4])
torch.Size([4])
import torch
>>>a = torch.Tensor([1,2,3,4,5])
>>>a.size()
torch.Size([5])
#use view to reshape

>>>b = a.view(1,a.shape[0])
>>>b
tensor([[1., 2., 3., 4., 5.]])
>>>b.size()
torch.Size([1, 5])
>>>b.type()
'torch.FloatTensor'

據我所知,重塑張量的最佳方法是使用einops 它通過提供簡單而優雅的功能解決了各種重塑問題。 在您的情況下,代碼可以寫成

from einops import rearrange
ans = rearrange(tensor,'h -> 1 h')

我強烈建議您嘗試一下。

順便說一句,您可以將它與 pytorch/tensorflow/numpy 和許多其他庫一起使用。

假設以下代碼:

import torch
import numpy as np
a = torch.tensor([1, 2, 3, 4, 5])

以下三個調用具有完全相同的效果:

res_1 = a.unsqueeze(0)
res_2 = a.view(1, 5)
res_3 = a[np.newaxis,:]
res_1.shape == res_2.shape == res_3.shape == (1,5)  # Returns true

請注意,對於任何生成的張量,如果您修改其中的數據,您也在修改 a 中的數據,因為它們沒有數據的副本,而是引用 a 中的原始數據。

res_1[0,0] = 2
a[0] == res_1[0,0] == 2  # Returns true

另一種方法是使用resize_ in place 操作:

a.shape == res_1.shape  # Returns false
a.reshape_((1, 5))
a.shape == res_1.shape # Returns true

小心使用resize_或其他就地操作autograd 請參閱以下討論: https ://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd

import torch
t = torch.ones((2, 3, 4))
t.size()
>>torch.Size([2, 3, 4])
a = t.view(-1,t.size()[1]*t.size()[2])
a.size()
>>torch.Size([2, 12])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM