簡體   English   中英

如何表示 PyTorch LSTM 3D 張量?

[英]How do I represent a PyTorch LSTM 3D Tensor?

根據文檔,我看到Pytorch's LSTM expects all of its inputs to be 3D tensors. 我正在嘗試做一個簡單的序列到序列 LSTM,我有:

class BaselineLSTM(nn.Module):
    def __init__(self):
        super(BaselineLSTM, self).__init__()

        self.lstm = nn.LSTM(input_size=100, hidden_size=100)

    def forward(self, x):
        print('x', x)
        x = self.lstm(x)

        return x

我的x.size()torch.Size([100, 1]) 我希望以某種方式需要第三維,但我不確定它的實際含義。 任何幫助將不勝感激。

輸入形狀在這個Pytorch 文檔的Inputs: input, (h_0, c_0)部分有進一步的闡述。 輸入張量的第一個維度預計對應於序列長度,第二個維度對應批次大小,第三個維度對應輸入大小。

所以對於你的例子,輸入張量x實際上應該是大小(seq_length, batch_size, 100)

這是 Pytorch 論壇上的詳細主題以獲取更多詳細信息: https ://discuss.pytorch.org/t/why-3d-input-tensors-in-lstm/4455/9

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM