

How to add a dropout layer in LSTM to avoid overfitting

While implementing a hybrid quantum LSTM model, I found that the model overfits and therefore gives low accuracy. I tried setting dropout = 1 in nn.LSTM, but it made no difference. I use a single hidden layer. How do I add a dropout layer to reduce overfitting?

Model parameters:

input_dim = 16
hidden_dim = 100
layer_dim = 1
output_dim = 1

Model class:

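import numpy as np
import qiskit
import torch
import torch.nn as nn

# Hybrid is the question author's quantum layer; its definition is not shown here.

class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        # dropout=1 is the setting discussed in the answer below
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, dropout=1, batch_first=True)

        self.fc = nn.Linear(hidden_dim, output_dim)
        # Quantum head evaluated on Qiskit's QASM simulator
        self.hybrid = Hybrid(qiskit.Aer.get_backend('qasm_simulator'), 100, np.pi / 2)

    def forward(self, x):
        # Zero initial hidden and cell states: (layer_dim, batch, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)

        x, (hn, cn) = self.lstm(x, (h0, c0))

        # Use only the output of the last time step
        x = self.fc(x[:, -1, :])
        x = self.hybrid(x)
        # Two-class output: the predicted value and its complement
        return torch.cat((x, 1 - x), -1)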

PyTorch's LSTM layer interprets the dropout parameter as the probability that each element of a layer's output is zeroed out. When you pass 1, every element is zeroed, so the whole layer output is wiped out. I assume you meant to use a conventional value such as 0.3 or 0.5.
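To make the meaning of the probability concrete, here is a small sketch using the standalone nn.Dropout module (nn.LSTM applies the same kind of dropout between its layers). The dummy tensor x and the chosen probabilities are illustrative assumptions. It shows that p = 1.0 zeroes everything, that p = 0.5 zeroes roughly half of the elements and rescales the survivors by 1 / (1 - p), and that dropout does nothing in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.ones(4, 10)  # dummy batch of activations, all equal to 1

# p=1.0: every element is zeroed, so nothing reaches the next layer
drop_all = nn.Dropout(p=1.0)
drop_all.train()
print(drop_all(x).abs().sum().item())  # 0.0

# p=0.5: each element is zeroed with probability 0.5; survivors are
# scaled by 1 / (1 - p) = 2.0 so the expected activation is unchanged
drop_half = nn.Dropout(p=0.5)
drop_half.train()
y = drop_half(x)
print(sorted(set(y.flatten().tolist())))  # only 0.0 and 2.0 can appear

# In eval mode dropout is the identity, regardless of p
drop_half.eval()
print(torch.equal(drop_half(x), x))  # True
```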

As @ayandas also says above, the built-in argument applies dropout to the outputs of every LSTM layer except the last (see the link above), so it has no effect on a single-layer LSTM; PyTorch only emits a warning when dropout is non-zero and num_layers is 1. That is why setting dropout = 1 changed nothing here. If you want dropout anyway, you can apply your own nn.Dropout to the output of your LSTM layer.
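The suggestion above can be sketched as follows. This is a minimal illustration, not the question author's model: the class name LSTMWithDropout and dropout_p = 0.3 are hypothetical choices, and the quantum Hybrid layer is replaced by a sigmoid purely so the example runs without Qiskit. The explicit nn.Dropout sits between the LSTM output and the fully connected layer.

```python
import torch
import torch.nn as nn


class LSTMWithDropout(nn.Module):
    """Single-layer LSTM followed by an explicit dropout layer (sketch).

    Assumption: torch.sigmoid stands in for the question's Hybrid layer.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout_p=0.3):
        super().__init__()
        # No dropout argument here: with layer_dim=1 it would have no effect
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # Explicit dropout on the LSTM output, active only in training mode
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Zero initial states are the default when (h0, c0) is omitted
        out, _ = self.lstm(x)                # out: (batch, seq_len, hidden_dim)
        last = self.dropout(out[:, -1, :])   # dropout on the last time step
        p = torch.sigmoid(self.fc(last))     # stand-in for self.hybrid(...)
        return torch.cat((p, 1 - p), -1)


model = LSTMWithDropout(input_dim=16, hidden_dim=100, layer_dim=1, output_dim=1)
batch = torch.randn(8, 20, 16)  # dummy data: (batch, seq_len, input_dim)

model.train()  # dropout active during training
train_out = model(batch)

model.eval()   # dropout disabled for validation / inference
with torch.no_grad():
    eval_out = model(batch)

print(train_out.shape, eval_out.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

Because nn.Dropout only acts in training mode, call model.train() before training and model.eval() before measuring accuracy. Alternatively, stacking two layers (layer_dim = 2) with a value such as dropout = 0.3 makes the built-in argument take effect between the two LSTM layers.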

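Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.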

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM