簡體   English   中英

如何在PyTorch中組合/堆疊張量並組合尺寸?

[英]How to combine/stack tensors and combine dimensions in PyTorch?

我需要將表示大小為[1,84,84]的表示灰度圖像的4個張量組合到表示形狀[4,84,84]的堆棧中,表示四個灰度圖像,每個圖像以張量樣式表示為“通道” CxWxH。

我正在使用PyTorch。

我已經嘗試過使用torch.stack和torch.cat,但是如果其中之一是解決方案,那么我沒有運氣找出正確的准備/方法來獲得結果。

感謝您的任何幫助。

import torchvision.transforms as T

class ReplayBuffer:
    def __init__(self, buffersize, batchsize, framestack, device, nS):
        self.buffer = deque(maxlen=buffersize)
        self.phi = deque(maxlen=framestack)
        self.batchsize = batchsize
        self.device = device

        self._initialize_stack(nS)

    def get_stack(self):
        #t =  torch.cat(tuple(self.phi),dim=0)
        t =  torch.stack(tuple(self.phi),dim=0)
        return t

    def _initialize_stack(self, nS):
        while len(self.phi) < self.phi.maxlen:
            self.phi.append(torch.tensor([1,nS[1], nS[2]]))

a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)

上面的代碼返回:

print(a.phi)

deque([tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84])], maxlen=4)

print(s, s.shape)

tensor([[ 1, 84, 84],
        [ 1, 84, 84],
        [ 1, 84, 84],
        [ 1, 84, 84]]) torch.Size([4, 3])

但是我想返回的就是[4,84,84]。 我懷疑這很簡單,但卻在逃避我。

似乎您誤解了torch.tensor([1, 84, 84])在做什么。 讓我們來看看:

torch.tensor([1, 84, 84])
print(x, x.shape) #tensor([ 1, 84, 84]) torch.Size([3])

您可以從上面的示例中看到,它為您提供了只有一個維度的張量。

從問題陳述中,您需要一個形張量[1,84,84]。 看起來像這樣:

from collections import deque
import torch
import torchvision.transforms as T

class ReplayBuffer:
    def __init__(self, buffersize, batchsize, framestack, device, nS):
        self.buffer = deque(maxlen=buffersize)
        self.phi = deque(maxlen=framestack)
        self.batchsize = batchsize
        self.device = device

        self._initialize_stack(nS)

    def get_stack(self):
        t =  torch.cat(tuple(self.phi),dim=0)
#         t =  torch.stack(tuple(self.phi),dim=0)
        return t

    def _initialize_stack(self, nS):
        while len(self.phi) < self.phi.maxlen:
#             self.phi.append(torch.tensor([1,nS[1], nS[2]]))
            self.phi.append(torch.zeros([1,nS[1], nS[2]]))

a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)

請注意, torch.cat給您提供形狀為[ torch.stack ]的張量,而torch.stack給您提供形狀為[ torch.stack ]的張量。 它們的區別可以在torch.stack()和torch.cat()函數之間有什么區別?

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM