[英]How to combine/stack tensors and combine dimensions in PyTorch?
我需要将表示大小为[1,84,84]的表示灰度图像的4个张量组合到表示形状[4,84,84]的堆栈中,表示四个灰度图像,每个图像以张量样式表示为“通道” CxWxH。
我正在使用PyTorch。
我已经尝试过使用torch.stack和torch.cat,但是如果其中之一是解决方案,那么我没有运气找出正确的准备/方法来获得结果。
感谢您的任何帮助。
import torchvision.transforms as T
class ReplayBuffer:
def __init__(self, buffersize, batchsize, framestack, device, nS):
self.buffer = deque(maxlen=buffersize)
self.phi = deque(maxlen=framestack)
self.batchsize = batchsize
self.device = device
self._initialize_stack(nS)
def get_stack(self):
#t = torch.cat(tuple(self.phi),dim=0)
t = torch.stack(tuple(self.phi),dim=0)
return t
def _initialize_stack(self, nS):
while len(self.phi) < self.phi.maxlen:
self.phi.append(torch.tensor([1,nS[1], nS[2]]))
a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)
上面的代码返回:
print(a.phi)
deque([tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84]), tensor([ 1, 84, 84])], maxlen=4)
print(s, s.shape)
tensor([[ 1, 84, 84],
[ 1, 84, 84],
[ 1, 84, 84],
[ 1, 84, 84]]) torch.Size([4, 3])
但是我想返回的就是[4,84,84]。 我怀疑这很简单,但却在逃避我。
似乎您误解了torch.tensor([1, 84, 84])
在做什么。 让我们来看看:
torch.tensor([1, 84, 84])
print(x, x.shape) #tensor([ 1, 84, 84]) torch.Size([3])
您可以从上面的示例中看到,它为您提供了只有一个维度的张量。
从问题陈述中,您需要一个形张量[1,84,84]。 看起来像这样:
from collections import deque
import torch
import torchvision.transforms as T
class ReplayBuffer:
def __init__(self, buffersize, batchsize, framestack, device, nS):
self.buffer = deque(maxlen=buffersize)
self.phi = deque(maxlen=framestack)
self.batchsize = batchsize
self.device = device
self._initialize_stack(nS)
def get_stack(self):
t = torch.cat(tuple(self.phi),dim=0)
# t = torch.stack(tuple(self.phi),dim=0)
return t
def _initialize_stack(self, nS):
while len(self.phi) < self.phi.maxlen:
# self.phi.append(torch.tensor([1,nS[1], nS[2]]))
self.phi.append(torch.zeros([1,nS[1], nS[2]]))
a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84])
print(a.phi)
s = a.get_stack()
print(s, s.shape)
请注意, torch.cat
给您提供形状为[ torch.stack
]的张量,而torch.stack
给您提供形状为[ torch.stack
]的张量。 它们的区别可以在torch.stack()和torch.cat()函数之间有什么区别?
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.