nn.LSTM doesn't seem to learn anything or isn't updating properly

I was trying out a simple LSTM use case from PyTorch, with the following model.

import torch.nn as nn

class SimpleLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(batch_first=True, input_size=embedding_dim, num_layers=1, hidden_size=hidden_dim, bidirectional=True)
        self.linear = nn.Linear(hidden_dim*2, 1)  # forward + backward features
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):   # NxD, padded to same length with 0s in N-sized batch
        x = self.embedding(x)
        output, (final_hidden_state, final_cell_state) = self.lstm(x)
        x = self.linear(output[:,-1,:])  # output at the last time step (index -1)
        x = self.sigmoid(x)
        return x
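One thing worth checking is the comment on forward(): batches are right-padded with 0s, so output[:, -1, :] is read at a padded position for every sequence shorter than the longest one in the batch. The sketch below (the sizes and token ids are arbitrary assumptions) runs the same three-token sequence with and without three trailing pad tokens through a bidirectional LSTM. The forward half agrees at the real last token, but not at index -1, because the LSTM keeps updating its state on the all-zero embeddings of the pad tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 8  # hidden size (arbitrary, for illustration)

embedding = nn.Embedding(10, 4, padding_idx=0)  # row 0 (the pad id) starts as zeros
lstm = nn.LSTM(input_size=4, hidden_size=H, batch_first=True, bidirectional=True)

real = torch.tensor([[5, 3, 7]])             # one sequence of 3 real tokens
padded = torch.tensor([[5, 3, 7, 0, 0, 0]])  # the same sequence right-padded with 0s

with torch.no_grad():
    out_real, _ = lstm(embedding(real))
    out_pad, _ = lstm(embedding(padded))

# The forward direction (first H features) is causal, so at the last real
# token (index 2) the padded and unpadded runs agree.
same_at_last_real = torch.allclose(out_real[0, 2, :H], out_pad[0, 2, :H], atol=1e-6)

# On the padded run, index -1 comes after 3 pad steps, so it no longer
# matches the state at the real last token.
same_at_minus_one = torch.allclose(out_real[0, -1, :H], out_pad[0, -1, :H], atol=1e-6)

print(same_at_last_real)   # True
print(same_at_minus_one)   # False
```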

It is a binary classification, using BCELoss (combined with the Sigmoid output layer). Unfortunately, the loss is stuck at 0.6969 (i.e. it is not learning anything).
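A loss stuck near 0.69 is a recognisable symptom: ln 2 ≈ 0.6931 is exactly the BCE of predicting 0.5 for every example, whatever the labels are. A quick check (the labels below are made up):

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical labels; the result does not depend on them.
targets = torch.tensor([[1.0], [0.0], [1.0], [1.0], [0.0]])
constant_preds = torch.full_like(targets, 0.5)  # a model that always outputs 0.5

loss = criterion(constant_preds, targets).item()
print(round(loss, 4), round(math.log(2), 4))  # 0.6931 0.6931
```

So a loss hovering around 0.69 usually means the sigmoid output has collapsed to roughly 0.5 for every input.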

I've tried feeding final_hidden_state and output[:,0,:] into the linear layer instead, but so far no dice.
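When using final_hidden_state with a bidirectional LSTM, its shape is (num_layers * 2, N, hidden_dim), not (N, hidden_dim * 2). The last layer's forward and backward states are h_n[-2] and h_n[-1], and they need to be concatenated before nn.Linear(hidden_dim*2, 1). The sketch below (arbitrary sizes, unpadded input) also shows how they relate to output: the forward state equals output[:, -1, :H] and the backward state equals output[:, 0, H:], so the backward half of output[:, -1, :] has only seen the last token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T, D, H = 2, 5, 4, 8  # batch, sequence length, embedding dim, hidden dim (arbitrary)

lstm = nn.LSTM(input_size=D, hidden_size=H, num_layers=1,
               batch_first=True, bidirectional=True)
x = torch.randn(N, T, D)  # unpadded input, so every sequence has length T

with torch.no_grad():
    output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([2, 5, 16])
print(h_n.shape)     # torch.Size([2, 2, 8]), i.e. (num_layers * 2, N, H)

# Concatenate the last layer's forward (h_n[-2]) and backward (h_n[-1]) states.
summary = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (N, 2 * H), fits nn.Linear(2 * H, 1)

# Forward final state = output at the last step; backward final state = output at step 0.
fwd_matches = torch.allclose(h_n[-2], output[:, -1, :H], atol=1e-6)
bwd_matches = torch.allclose(h_n[-1], output[:, 0, H:], atol=1e-6)
print(summary.shape, fwd_matches, bwd_matches)  # torch.Size([2, 16]) True True
```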

Everything else (optimizer, loss criterion, train loop, val loop) already works: I tried the exact same setup with a basic NN using only nn.Embedding, nn.Linear, and nn.Sigmoid, and got a good loss decrease and high accuracy. In SimpleLSTM, the only thing I added is the nn.LSTM.

  • Typically, final_hidden_state is what gets passed to the linear layer, not output. Use it.
  • Add 1-2 more linear layers after the LSTM.
  • Try a lower learning rate (LR), especially when the embeddings are not pre-trained.
  • Better yet, try loading pre-trained embeddings.
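Putting these suggestions together, here is one possible revision of SimpleLSTM. It is a sketch, not the answerer's code, and it assumes the caller can pass each sequence's true (unpadded) length. Packing with torch.nn.utils.rnn.pack_padded_sequence is a technique swapped in here to deal with the 0-padding (the answer does not mention it): the LSTM then skips pad positions, so h_n holds each sequence's state at its real last token. The class name PackedLSTMClassifier, the sizes, the 1e-4 learning rate and the optional pretrained_vectors argument are all illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class PackedLSTMClassifier(nn.Module):
    """Hypothetical revision of SimpleLSTM; names and sizes are illustrative."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, pretrained_vectors=None):
        super().__init__()
        if pretrained_vectors is not None:
            # Optional pre-trained vectors (vocab_size x embedding_dim);
            # padding_idx=0 keeps row 0 out of gradient updates.
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_vectors, freeze=False, padding_idx=0)
        else:
            self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                            num_layers=1, batch_first=True, bidirectional=True)
        # Extra hidden layer after the LSTM, as suggested above.
        self.head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, lengths):
        # x: (N, T) token ids right-padded with 0; lengths: true length of each row
        emb = self.embedding(x)
        packed = pack_padded_sequence(emb, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n: (2, N, hidden_dim) in the original batch order;
        # h_n[-2] is the forward state, h_n[-1] the backward state
        summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.head(summary)


torch.manual_seed(0)
model = PackedLSTMClassifier(vocab_size=20, embedding_dim=8, hidden_dim=16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lower LR (assumed value)
criterion = nn.BCELoss()

# Toy batch: two sequences, the second right-padded with 0s.
x = torch.tensor([[4, 7, 2, 9], [5, 3, 0, 0]])
lengths = torch.tensor([4, 2])
y = torch.tensor([[1.0], [0.0]])

# One training step.
probs = model(x, lengths)
loss = criterion(probs, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(probs.shape)  # torch.Size([2, 1])
```

Because pad positions are skipped, the padded second sequence gets the same prediction in the batch as when it is run on its own.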

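Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.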

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM