How to make LSTMClassifier Bidirectional?

Goal: make self.classifier() learn from both directions of a bidirectional LSTM.

# ! = line of interest

Question: What changes to LSTMClassifier do I need to make, in order to have this LSTM work bidirectionally?

When I pass bidirectional=True to self.lstm = nn.LSTM(...), I get this traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-51-b94d572a1b68> in <module>()
     11     """.split()
---> 13 run_training(args)

3 frames
<ipython-input-8-bb0d8b014e32> in run_training(input)
     54     elif args.checkpointfile:
     55         file_path = os.path.join(args.traindir, args.checkpointfile)
---> 56         model = LSTMTaggerModel.load_from_checkpoint(file_path)
     57     else:
     58         model = LSTMTaggerModel(**vars(args), num_classes=dm.num_classes, class_map=dm.class_map)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    155         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
--> 157         model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    158         return model

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
    204         # load the state_dict on the model automatically
--> 205         model.load_state_dict(checkpoint['state_dict'], strict=strict)
    207         return model

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for LSTMTaggerModel:
    Missing key(s) in state_dict: "model.lstm.weight_ih_l0_reverse", "model.lstm.weight_hh_l0_reverse", "model.lstm.bias_ih_l0_reverse", "model.lstm.bias_hh_l0_reverse".

I think the problem is in forward(). It classifies from the last time step of the LSTM output, obtained by slicing:

tag_space = self.classifier(lstm_out[:,-1,:])

However, bidirectional changes the architecture and thus the output shape.

Do I need to sum or concatenate the values of the two directions?

Working Code:

from argparse import ArgumentParser

import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):

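    # Signature reconstructed from the names used in the body below.
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, batch_size):

        super(LSTMClassifier, self).__init__()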

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)  # !
        self.classifier = nn.Linear(hidden_dim, self.num_labels[0])

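    @staticmethod
    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            # h is a tuple such as (h_n, c_n): detach each element recursively.
            return tuple(LSTMClassifier.repackage_hidden(v) for v in h)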

    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.classifier(lstm_out[:,-1,:])  # !
        logits = F.log_softmax(tag_space, dim=1)
        loss = None
        if labels is not None:  # truth-testing a multi-element tensor raises an error
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
        return loss, logits

It sounds like you're trying to load a pretrained model (which uses a unidirectional LSTM) into a model that has a bidirectional LSTM, so the checkpoint's state dict lacks the reverse-direction parameters. There are several things you can do here, because the pretrained state dict and the bidirectional model's state dict inherently differ:

  1. Use model.load_state_dict(model_params, strict=False). This stops the error about keys that are missing from the checkpoint. The forward direction will then be pretrained, while the backward direction keeps its fresh initialisation. Your traceback goes through load_from_checkpoint, which also takes a strict argument, so you can pass strict=False there.
  2. If you do this, you will need to sum or otherwise condense the final states of the forward and backward directions; otherwise the classifier's input is twice as wide and its weight has a different shape. Note that strict=False only tolerates missing or unexpected keys, not shape mismatches, so if you concatenate instead, the pretrained classifier weight no longer fits and has to be left out of the state dict. Summing is therefore only worth it if you care about reusing the pretrained classifier layer.
  3. If you don't want to do the above two, you can copy the weights for model.lstm.weight_ih_l0_reverse and the other missing parameters from the forward direction in the state dict, as it's just a Python dictionary. This is not ideal, because the forward and backward directions should learn different things, but it stops the error and gives a reasonably good initialisation. You will still have the issue from point 2, though: the LSTM output is twice as wide as it was.
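To make the sum-versus-concatenate choice concrete, here is a minimal sketch of a bidirectional variant of the classifier. The class name BiLSTMClassifier and the combine argument are hypothetical, not part of the original code. It reads the final hidden states from h_n instead of slicing lstm_out[:, -1, :], because in a bidirectional LSTM the backward half of the last time step has only seen the last token.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BiLSTMClassifier(nn.Module):
    """Hypothetical bidirectional variant of LSTMClassifier (illustrative names)."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_labels, combine="concat"):
        super().__init__()
        self.combine = combine
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Concatenation doubles the classifier input; summing keeps it at hidden_dim.
        in_features = 2 * hidden_dim if combine == "concat" else hidden_dim
        self.classifier = nn.Linear(in_features, num_labels)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # h_n has shape (num_layers * num_directions, batch, hidden_dim).
        _, (h_n, _) = self.lstm(embeds)
        # Last layer: h_n[-2] is the forward final state, h_n[-1] the backward one.
        h_fwd, h_bwd = h_n[-2], h_n[-1]
        if self.combine == "concat":
            features = torch.cat([h_fwd, h_bwd], dim=1)
        else:
            features = h_fwd + h_bwd
        return F.log_softmax(self.classifier(features), dim=1)


model = BiLSTMClassifier(vocab_size=20, embedding_dim=8, hidden_dim=16, num_labels=3)
batch = torch.randint(0, 20, (4, 7))  # 4 sentences of 7 token ids each
log_probs = model(batch)
print(log_probs.shape)
```

With combine="sum" the classifier keeps its original input width, so a pretrained classifier weight still fits; with "concat" it does not.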

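Options 1 and 3 can be combined by patching the state dict before loading it. The sketch below uses a hypothetical helper, adapt_unidirectional_state_dict, and a toy Tiny module standing in for the real model; the prefix "lstm." is an assumption, and for the Lightning checkpoint in the question it would be "model.lstm.". It copies the forward weights into the missing _reverse keys and drops entries whose shape no longer matches, since strict=False does not tolerate size mismatches.

```python
import torch
import torch.nn as nn


def adapt_unidirectional_state_dict(state_dict, model, prefix="lstm."):
    """Hypothetical helper: make a unidirectional state dict loadable into a bidirectional model."""
    target = model.state_dict()
    patched = dict(state_dict)
    for key, value in state_dict.items():
        reverse_key = key + "_reverse"
        # Initialise each missing reverse-direction tensor from its forward twin.
        if key.startswith(prefix) and reverse_key in target and reverse_key not in patched:
            patched[reverse_key] = value.clone()
    # Keep only entries that exist in the target with an identical shape.
    return {k: v for k, v in patched.items() if k in target and v.shape == target[k].shape}


class Tiny(nn.Module):
    """Toy stand-in for the real model, used only for this demonstration."""

    def __init__(self, bidirectional):
        super().__init__()
        self.lstm = nn.LSTM(8, 16, batch_first=True, bidirectional=bidirectional)
        self.classifier = nn.Linear(16 * (2 if bidirectional else 1), 3)


old, new = Tiny(bidirectional=False), Tiny(bidirectional=True)
patched = adapt_unidirectional_state_dict(old.state_dict(), new)
result = new.load_state_dict(patched, strict=False)
print(result.missing_keys)  # parameters that keep their fresh initialisation
```

Here the concatenating classifier's weight is the only tensor left at its fresh initialisation, so it still has to be trained before the model is useful.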
