
PyTorch LSTM model not learning

I created a simple LSTM model to predict the Uniqlo closing price. The problem is that my model doesn't seem to learn anything. This is the link to my notebook.

This is my model class (I tried the ReLU activation function before and got the same outcome):

```python
import torch

class lstm(torch.nn.Module):
    def __init__(self, hidden_layers):
        super(lstm, self).__init__()
        self.hidden_layers = hidden_layers  # number of stacked LSTM layers
        self.lstm = torch.nn.LSTM(input_size=2, hidden_size=100,
                                  num_layers=self.hidden_layers, batch_first=True)
        self.hidden1 = torch.nn.Linear(100, 80)
        self.dropout1 = torch.nn.Dropout(0.1)
        self.hidden2 = torch.nn.Linear(80, 60)
        self.dropout2 = torch.nn.Dropout(0.1)
        self.output = torch.nn.Linear(60, 1)

    def forward(self, x):
        batch = len(x)
        # Random initial hidden and cell states, drawn anew on every call,
        # created on the same device as the input
        h = torch.randn(self.hidden_layers, batch, 100, device=x.device).requires_grad_()
        c = torch.randn(self.hidden_layers, batch, 100, device=x.device).requires_grad_()

        # Input reshaped to (batch, 10 time steps, 2 features)
        x, (ho, co) = self.lstm(x.view(batch, 10, 2), (h.detach(), c.detach()))
        x = torch.reshape(x[:, -1, :], (batch, -1))  # output of the last time step
        x = self.hidden1(x)
        x = torch.tanh(x)  # torch.nn.functional.tanh is deprecated
        x = self.dropout1(x)
        x = self.hidden2(x)
        x = torch.tanh(x)
        x = self.dropout2(x)
        x = self.output(x)
        return x

model = lstm(10)
```

This is my training loss plot:

[Figure: training loss]

This is my validation loss plot:

[Figure: validation loss]

This is my ground truth (blue) vs. prediction (orange):

[Figure: ground truth vs. prediction]

Can anyone point out what I did wrong?

You fit the scaler on the whole dataset. This is not a good strategy, because statistics from the validation period leak into the training inputs. Fit it on the training data only.
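A minimal sketch of fitting the scaling statistics on the training split only, in plain Python. It assumes the data is a chronologically ordered list of hypothetical (open, volume) rows and uses an 80/20 time-based split; both are illustrative choices, not taken from the notebook. With scikit-learn's `MinMaxScaler`, the same rule means calling `fit` on the training array only and `transform` on both splits.

```python
from typing import List, Tuple

Row = List[float]

def time_split(rows: List[Row], train_frac: float = 0.8) -> Tuple[List[Row], List[Row]]:
    """Split chronologically: no shuffling, validation rows come after training rows."""
    cut = int(len(rows) * train_frac)
    return rows[:cut], rows[cut:]

def fit_min_max(train: List[Row]) -> Tuple[Row, Row]:
    """Per-column min and max, computed from the training rows only."""
    cols = list(zip(*train))
    return [min(c) for c in cols], [max(c) for c in cols]

def transform(rows: List[Row], lo: Row, hi: Row) -> List[Row]:
    """Scale each column with the training statistics.
    Values outside the training range map outside [0, 1]; that is expected."""
    return [
        [(v - l) / (h - l) if h > l else 0.0 for v, l, h in zip(row, lo, hi)]
        for row in rows
    ]

# Hypothetical (open, volume) rows in chronological order.
rows = [[100.0, 10.0], [102.0, 12.0], [101.0, 11.0], [105.0, 15.0], [110.0, 20.0]]
train, valid = time_split(rows, train_frac=0.8)
lo, hi = fit_min_max(train)
train_scaled = transform(train, lo, hi)
valid_scaled = transform(valid, lo, hi)
```

Here the validation row scales to 2.0 in both columns: the scaler never saw it, which is exactly the situation the model faces on future data.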

You don't have to scale the targets. Use them directly, apply a log transform, or predict returns instead.
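A minimal sketch of those target transformations in plain Python, on a hypothetical list of closing prices. Returns are one element shorter than the price series, so the targets must be aligned with the inputs accordingly; a predicted log return can be mapped back to a price using the previous close.

```python
import math
from typing import List

def log_prices(closes: List[float]) -> List[float]:
    """Natural log of each closing price (compresses the scale, nothing to fit)."""
    return [math.log(p) for p in closes]

def simple_returns(closes: List[float]) -> List[float]:
    """r_t = p_t / p_{t-1} - 1; one element shorter than the input."""
    return [cur / prev - 1.0 for prev, cur in zip(closes, closes[1:])]

def log_returns(closes: List[float]) -> List[float]:
    """log(p_t / p_{t-1}); also one element shorter than the input."""
    return [math.log(cur / prev) for prev, cur in zip(closes, closes[1:])]

def close_from_log_return(prev_close: float, predicted_log_return: float) -> float:
    """Turn a predicted log return back into a price level."""
    return prev_close * math.exp(predicted_log_return)

closes = [100.0, 110.0, 99.0]  # hypothetical closing prices
```

Log returns are convenient because they add up over time: their sum equals the log of the total price change over the period.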

About the hidden state and cell memory: why do you enable gradient tracking on them and then detach them right after? You don't have to detach the hidden state and cell memory when feeding them to the LSTM layer, because they take part in backpropagation.
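Because `h` and `c` are created fresh inside `forward`, there is no earlier graph attached to them, so `.requires_grad_()` followed by `.detach()` has no useful effect; drawing them with `torch.randn` also feeds different noise into every forward pass. A minimal sketch of the usual alternative, assuming the question's sizes (2 input features, hidden size 100, windows of 10 steps): start from zero states, which is what `torch.nn.LSTM` does when no state is passed. The class name `LSTMRegressor` and the simplified single-layer head are hypothetical.

```python
import torch
from torch import nn

class LSTMRegressor(nn.Module):
    """Hypothetical variant of the question's model with zero initial states."""

    def __init__(self, num_layers: int = 2, hidden_size: int = 100, n_features: int = 2):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(n_features, hidden_size, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, n_features), since batch_first=True
        batch = x.size(0)
        # Explicit zero states on the input's device and dtype;
        # equivalent to calling self.lstm(x) without a state argument.
        h0 = x.new_zeros(self.num_layers, batch, self.hidden_size)
        c0 = x.new_zeros(self.num_layers, batch, self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        # Predict one value from the output at the last time step.
        return self.head(out[:, -1, :])

model = LSTMRegressor(num_layers=2)
x = torch.randn(4, 10, 2)  # hypothetical batch: 4 windows, 10 steps, 2 features
y = model(x)
```

Passing the zeros explicitly only makes the default visible. If you want the initial state to be learned instead, it has to be registered as an `nn.Parameter`, which is a separate design choice.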

If I understand what you did correctly, you predict the next closing price from the last 10 opening prices and volumes. I don't think you can get good results with this configuration; you should formulate the problem more carefully.
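A minimal sketch of how inputs and targets line up in that setup, assuming hypothetical parallel lists `opens`, `volumes` and `closes` of equal length. Each sample holds 10 consecutive (open, volume) pairs, and its target is taken, as an assumption, to be the close of the day right after the window. Writing the windowing out explicitly makes it easy to check what the model actually sees and to change the formulation, for example by adding past closes as a feature.

```python
from typing import List, Tuple

def make_windows(
    opens: List[float],
    volumes: List[float],
    closes: List[float],
    window: int = 10,
) -> Tuple[List[List[Tuple[float, float]]], List[float]]:
    """Build (X, y): X[i] holds `window` consecutive (open, volume) pairs,
    y[i] is the close of the first day after that window (an assumption)."""
    if not (len(opens) == len(volumes) == len(closes)):
        raise ValueError("series must have the same length")
    X, y = [], []
    for start in range(len(opens) - window):
        end = start + window  # index of the first day after the window
        X.append(list(zip(opens[start:end], volumes[start:end])))
        y.append(closes[end])
    return X, y

# Hypothetical 13-day series.
opens = [float(i) for i in range(13)]
volumes = [float(10 * i) for i in range(13)]
closes = [float(100 + i) for i in range(13)]
X, y = make_windows(opens, volumes, closes, window=10)
```

`torch.tensor(X)` then has shape `(num_samples, 10, 2)`, which matches the `x.view(batch, 10, 2)` call in the question's `forward`.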
