Looping over pytorch LSTM

I am training a seq2seq model for machine translation in PyTorch. I would like to gather the cell state at every time step, while still keeping the flexibility of multiple layers and bidirectionality that you get from PyTorch's LSTM module.

To this end, I have the following encoder and forward method, where I loop over the LSTM module one time step at a time. The problem is that the model does not train very well. Right after the loop you can see the commented-out, normal way of calling the LSTM module, and with that the model trains.

So, is the loop not a valid way to do this?

import torch.nn as nn

class encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_dim, emb_dim)

        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):

        #src = [src sent len, batch size]

        embedded = self.dropout(self.embedding(src))

        #embedded = [src sent len, batch size, emb dim]
        hidden_all = []

        # feed the LSTM one time step at a time and collect the
        # (hidden, cell) tuple after every step
        for i in range(embedded.size(0)):
            outputs, hidden = self.rnn(embedded[i, :, :].unsqueeze(0))
            hidden_all.append(hidden)

        # the normal, non-looping call, with which the model trains:
        #outputs, hidden = self.rnn(embedded)

        #outputs = [src sent len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        #outputs are always from the top hidden layer

        return hidden

OK, the fix is very simple: run the first time step outside the loop, and from then on feed the hidden tuple back into the LSTM module.
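Concretely, here is a minimal sketch of that fix, written against the encoder from the question (same variable names; an illustrative rewrite of its forward method, not the original poster's exact code). When no state is passed in, nn.LSTM starts from zero hidden and cell states on every call, so the original loop reset the state at each time step and no information could flow along the sequence.

    def forward(self, src):
        #src = [src sent len, batch size]
        embedded = self.dropout(self.embedding(src))
        #embedded = [src sent len, batch size, emb dim]

        hidden_all = []

        # first time step outside the loop: no state is passed,
        # so the LSTM initializes (hidden, cell) to zeros
        outputs, hidden = self.rnn(embedded[0, :, :].unsqueeze(0))
        hidden_all.append(hidden)

        # remaining time steps: carry the (hidden, cell) tuple forward
        for i in range(1, embedded.size(0)):
            outputs, hidden = self.rnn(embedded[i, :, :].unsqueeze(0), hidden)
            hidden_all.append(hidden)

        return hidden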
