
How to make LSTMClassifier Bidirectional?

Goal: have self.classifier() learn from both directions of the LSTM.

# ! = line of interest

Question: What changes do I need to make to LSTMClassifier in order to have this LSTM work bidirectionally?


When passing bidirectional=True to self.lstm = nn.LSTM(...), I get this traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-51-b94d572a1b68> in <module>()
     11     """.split()
     12 
---> 13 run_training(args)

3 frames
<ipython-input-8-bb0d8b014e32> in run_training(input)
     54     elif args.checkpointfile:
     55         file_path = os.path.join(args.traindir, args.checkpointfile)
---> 56         model = LSTMTaggerModel.load_from_checkpoint(file_path)
     57     else:
     58         model = LSTMTaggerModel(**vars(args), num_classes=dm.num_classes, class_map=dm.class_map)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    155         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    156 
--> 157         model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    158         return model
    159 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
    203 
    204         # load the state_dict on the model automatically
--> 205         model.load_state_dict(checkpoint['state_dict'], strict=strict)
    206 
    207         return model

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1409 

RuntimeError: Error(s) in loading state_dict for LSTMTaggerModel:
    Missing key(s) in state_dict: "model.lstm.weight_ih_l0_reverse", "model.lstm.weight_hh_l0_reverse", "model.lstm.bias_ih_l0_reverse", "model.lstm.bias_hh_l0_reverse".

I think the problem is with forward(). It learns from only the last time step of the LSTM output, obtained by slicing:

tag_space = self.classifier(lstm_out[:,-1,:])

However, bidirectional changes the architecture and thus the output shape.

Do I need to sum or concatenate the outputs of the two directions?
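For reference, here is the shape difference I mean (a minimal, illustrative check; the tensor names are made up and this is not part of the training code):

import torch
import torch.nn as nn

batch, seq_len, embedding_dim, hidden_dim = 10, 7, 100, 50
x = torch.randn(batch, seq_len, embedding_dim)

uni = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
bi  = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)

uni_out, _ = uni(x)  # (10, 7, 50)  -> lstm_out[:, -1, :] has width hidden_dim
bi_out, _  = bi(x)   # (10, 7, 100) -> forward and backward outputs are concatenated
fwd_last = bi_out[:, -1, :hidden_dim]  # last step of the forward direction
bwd_last = bi_out[:, 0, hidden_dim:]   # first step of the backward direction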


Working Code:

from argparse import ArgumentParser

import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):

    def __init__(self, 
        num_classes, 
        batch_size=10,
        embedding_dim=100, 
        hidden_dim=50, 
        vocab_size=128):

        super(LSTMClassifier, self).__init__()

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)  # !
        self.classifier = nn.Linear(hidden_dim, self.num_labels[0])


    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)


    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.classifier(lstm_out[:,-1,:])  # !
        logits = F.log_softmax(tag_space, dim=1)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
        return loss, logits

It sounds like you're trying to load a pretrained checkpoint (which uses a unidirectional LSTM) into a model whose state dict expects a bidirectional LSTM. There are several things you can do here, since the two state dicts inherently differ:

  1. Definitely use model.load_state_dict(model_params, strict=False). This stops the loader from complaining when the checkpoint and the model don't match exactly: the forward direction of the LSTM is initialised from the checkpoint, while the backward (reverse) direction starts from random weights.
  2. If you do this, you will also need to sum (or otherwise condense back to hidden_dim) the final time steps of the forward and backward directions, because concatenating them doubles the classifier's input width and the pretrained classifier weights no longer fit. strict=False only silences missing or unexpected keys, not shape mismatches, so only concatenate if you're happy to reinitialise the classifier (see the sketch after this list).
  3. If you don't want to do either of the above, you can copy the weights for model.lstm.weight_ih_l0_reverse and the other missing parameters from the forward direction in the state dict, since it's just a Python dictionary. It's not ideal, because the forward and backward directions will obviously learn different things, but it stops the error and starts you from a reasonably good initialisation. You will still hit the issue from point 2, though, because the LSTM output is now twice as wide as before.
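Roughly, the concatenation route from point 2 could look like the sketch below (it reuses the names and imports from the question's code and is not a drop-in replacement; only the lines marked # ! change relative to the original):

class LSTMClassifier(nn.Module):

    def __init__(self, num_classes, batch_size=10, embedding_dim=100,
                 hidden_dim=50, vocab_size=128):
        super(LSTMClassifier, self).__init__()
        initrange = 0.1
        self.num_labels = num_classes
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                            batch_first=True, bidirectional=True)        # !
        # concatenating both directions doubles the classifier input width
        self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0])  # !

    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)               # (batch, seq_len, 2 * hidden_dim)
        fwd_last = lstm_out[:, -1, :self.hidden_dim]  # last step, forward direction
        bwd_last = lstm_out[:, 0, self.hidden_dim:]   # first step, backward direction
        tag_space = self.classifier(torch.cat([fwd_last, bwd_last], dim=1))  # !
        logits = F.log_softmax(tag_space, dim=1)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
        return loss, logits

Summing fwd_last + bwd_last instead of concatenating would keep the classifier input at hidden_dim, which also lets the pretrained classifier weights load.

For point 3, since the checkpoint's state dict is just a Python dictionary, something along these lines seeds the missing *_reverse parameters from the forward-direction weights before loading (file_path and model are the objects from run_training in the question; the key filter assumes the key names shown in the error message):

checkpoint = torch.load(file_path, map_location="cpu")
state_dict = checkpoint["state_dict"]
for key in list(state_dict.keys()):
    if ".lstm." in key and key.endswith("_l0"):
        state_dict[key + "_reverse"] = state_dict[key].clone()
# if you concatenate, the old classifier weights no longer fit 2 * hidden_dim, so drop them
state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k}
model.load_state_dict(state_dict, strict=False)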
