How to pad zeros on Batch, PyTorch
Is there a better way to do this? How can I pad a tensor with zeros without creating a new tensor object? I need the input to always have the same batch size, so I want to zero-pad any input that is smaller than the batch size, just like zero-padding shorter sequences in NLP, except here the padding is along the batch dimension.
Currently I create a new tensor, but because of that my GPU runs out of memory. I don't want to halve the batch size just to handle this.
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, batchsize=16):
        super().__init__()
        self.batchsize = batchsize

    def forward(self, x):
        b, d = x.shape
        print(x.shape)  # torch.Size([7, 32])
        if b != self.batchsize:  # 2. I need batches to be of size 16; if the batch isn't 16, I want to pad the rest with zeros
            new_x = torch.zeros(self.batchsize, d)  # 3. so I create a new tensor, but this is bad as it greatly increases the GPU memory required
            new_x[0:b, :] = x
            x = new_x
            b = self.batchsize
        print(x.shape)  # torch.Size([16, 32])
        return x

model = MyModel()
x = torch.randn((7, 32))  # 1. the batch dimension is 7 because this is the last batch, and I don't want to use "drop_last"
y = model(x)
print(y.shape)
You can pad the extra elements like this:
import torch.nn.functional as F
n = self.batchsize - b
new_x = F.pad(x, (0,0,n,0)) # pad the start of 2d tensors
new_x = F.pad(x, (0,0,0,n)) # pad the end of 2d tensors
new_x = F.pad(x, (0,0,0,0,0,n)) # pad the end of 3d tensors
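Putting it together, here is a minimal sketch of the model's forward pass using F.pad to pad the end of the batch dimension (note that F.pad still allocates a new tensor internally, but it keeps the result on the same device and dtype as the input, so no explicit zeros tensor is needed):

```python
import torch
import torch.nn.functional as F
from torch import nn

class MyModel(nn.Module):
    def __init__(self, batchsize=16):
        super().__init__()
        self.batchsize = batchsize

    def forward(self, x):
        b = x.shape[0]
        if b != self.batchsize:
            # pad tuple is (last-dim left, last-dim right, first-dim before, first-dim after):
            # pad the end of the batch (first) dimension with zeros up to batchsize
            x = F.pad(x, (0, 0, 0, self.batchsize - b))
        return x

model = MyModel()
y = model(torch.randn(7, 32))
print(y.shape)  # torch.Size([16, 32])
```

The pad tuple is read from the last dimension backwards, which is why the batch-dimension amounts come last in the tuple for a 2D tensor.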