繁体   English   中英

如何在Pytorch中填充3维张量?

[英]How does one padd a tensor of 3 dimensions in Pytorch?

我试图使用内置的padding函数,但是出于某种原因,它并没有为我填充内容。 这是我的可复制代码:

import torch

def padding_batched_embedding_seq():
    ## 3 sequences with embedding of size 300
    a = torch.ones(1, 4, 5) # 25 seq len (so 25 tokens)
    b = torch.ones(1, 3, 5) # 22 seq len (so 22 tokens)
    c = torch.ones(1, 2, 5) # 15 seq len (so 15 tokens)
    ##
    sequences = [a, b, c]
    batch = torch.nn.utils.rnn.pad_sequence(sequences)

if __name__ == '__main__':
    padding_batched_embedding_seq()

错误信息:

Traceback (most recent call last):
  File "padding.py", line 51, in <module>
    padding_batched_embedding_seq()
  File "padding.py", line 40, in padding_batched_embedding_seq
    batch = torch.nn.utils.rnn.pad_sequence(sequences)
  File "/Users/rene/miniconda3/envs/automl/lib/python3.7/site-packages/torch/nn/utils/rnn.py", line 376, in pad_sequence
    out_tensor[:length, i, ...] = tensor
RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1.  Target sizes: [1, 4, 5].  Tensor sizes: [3, 5]

任何想法?


交叉发布: https : //discuss.pytorch.org/t/how-does-one-padd-a-tensor-of-3-dimensions/51097

您可以使用torch.ones(2,5)代替,也可以使用torch.ones(2,...),其中...对于每个样本都是相同的尺寸。 RuntimeError:张量(4)的扩展大小必须与非单维度1上的现有大小(3)相匹配。目标大小:[1、4、5]。 张量大小:[3,5]表示它期望除第一个〜dim == 0以外的所有尺寸都相同,因为第一个是可变的seq长度,而其他则是相同的输入项。

来自doc https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html的示例为:

 >>> from torch.nn.utils.rnn import pad_sequence
    >>> a = torch.ones(25, 300)
    >>> b = torch.ones(22, 300)
    >>> c = torch.ones(15, 300)
    >>> pad_sequence([a, b, c]).size()

输出:torch.Size([25,3,300])

使用shape:(max_sequence len,batch_size,single_input),因为默认情况下batch_first = False,但是我更喜欢batch_first = True形状为torch.Size([3,25,300])。

填充仅表示填充零,直到它与最大序列len匹配。 作为RNN中的输入,您可能更喜欢不包含零输入的打包序列。

因此,在您的示例中,如果输入具有更多的暗淡效果,它将像

 a = torch.ones(4, 5, 10) # 5*10 2d input,  sequence of length 4 for them
    b = torch.ones(3, 5, 10) 
    c = torch.ones(2, 5, 10)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM