
How to pad zeros on Batch, PyTorch

Is there a better way to do this? How can I pad a tensor with zeros without creating a new tensor object? I need the input to always have the same batch size, so I want to zero-pad any input that is smaller than the batch size. It's like padding with zeros in NLP when a sequence is shorter, except here the padding is along the batch dimension.

Currently I create a new tensor, but because of that my GPU runs out of memory. I don't want to cut the batch size in half just to handle this operation.

import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, batchsize=16):
        super().__init__()
        self.batchsize = batchsize
    
    def forward(self, x):
        b, d = x.shape
        
        print(x.shape) # torch.Size([7, 32])

        if b != self.batchsize: # 2. I need batches to be of size 16; if the batch isn't 16, I want to pad the rest with zeros
            # 3. so I create a new tensor, but this is bad as it greatly increases the GPU memory required
            new_x = torch.zeros(self.batchsize, d, device=x.device, dtype=x.dtype)
            new_x[0:b, :] = x
            x = new_x
            b = self.batchsize
        
        print(x.shape) # torch.Size([16, 32])

        return x

model = MyModel()
x = torch.randn((7, 32)) # 1. the batch is 7 because this is the last batch, and I don't want to "drop_last"
y = model(x)
print(y.shape)

You can pad the extra elements like this:

import torch.nn.functional as F

n = self.batchsize - b  # number of rows to pad

new_x = F.pad(x, (0, 0, n, 0))       # pad the start of a 2d tensor's first dim
new_x = F.pad(x, (0, 0, 0, n))       # pad the end of a 2d tensor's first dim
new_x = F.pad(x, (0, 0, 0, 0, 0, n)) # pad the end of a 3d tensor's first dim
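To make the answer concrete, here is a minimal, self-contained check of the `F.pad` approach applied to the question's 7×32 last batch. Note that `F.pad`'s pad tuple is read from the *last* dimension backwards, so `(0, 0, 0, n)` means "no padding on dim 1, append `n` zero rows to the end of dim 0":

```python
import torch
import torch.nn.functional as F

batchsize = 16
x = torch.randn(7, 32)      # the last, short batch
n = batchsize - x.shape[0]  # rows of zeros needed

# Pad the end of the batch dimension without hand-allocating a zeros tensor.
new_x = F.pad(x, (0, 0, 0, n))

print(new_x.shape)                    # torch.Size([16, 32])
print(new_x[7:].abs().sum().item())   # 0.0 — the padded rows are all zeros
```

The padded tensor stays on the same device and dtype as `x`, so no extra device handling is needed inside `forward`.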


Disclaimer: the technical posts on this site follow the CC BY-SA 4.0 license. If you repost, please credit this site or the original source. For any questions, contact: yoyou2525@163.com.
