
PyTorch RNN loss does not decrease and validation accuracy remains unchanged

I'm training a model using a PyTorch GRU on a text-classification task (the output dimension is 5). My network is implemented as shown in the code below.

import torch
import torch.nn as nn


class GRU(nn.Module):

    def __init__(self, model_param: ModelParam):
        super(GRU, self).__init__()

        self.embedding = nn.Embedding(model_param.vocab_size, model_param.embed_dim)

        # Build with pre-trained embedding vectors, if given.
        if model_param.vocab_embedding is not None:
            self.embedding.weight.data.copy_(model_param.vocab_embedding)
            self.embedding.weight.requires_grad = False

        self.rnn = nn.GRU(model_param.embed_dim,
                          model_param.hidden_dim,
                          num_layers=2,
                          bias=True,
                          batch_first=True,
                          dropout=0.5,
                          bidirectional=False)

        self.dropout = nn.Dropout(0.5)

        self.fc = nn.Sequential(
            nn.Linear(in_features=model_param.hidden_dim, out_features=128),
            nn.Linear(in_features=128, out_features=model_param.output_dim)
        )

    def forward(self, x, labels=None):
        '''
            :param x: torch.tensor, of shape [batch_size, max_seq_len].
            :param labels: torch.tensor, of shape [batch_size]. Not used in this model.
            :return outputs: torch.tensor, of shape [batch_size, output_dim].
        '''

        # [batch_size, max_seq_len, embed_dim].
        features = self.dropout(self.embedding(x))
        
        # [batch_size, max_seq_len, hidden_dim].
        outputs, _ = self.rnn(features)

        # [batch_size, hidden_dim].
        outputs = outputs[:, -1, :]

        return self.fc(self.dropout(outputs))
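
As a sanity check, the module can be exercised with dummy inputs to confirm the output shape; the ModelParam values below are made up for illustration.

# Shape check with made-up hyper-parameters (not from the original post).
class ModelParam:
    vocab_size = 10000
    embed_dim = 100
    hidden_dim = 256
    output_dim = 5
    vocab_embedding = None  # no pre-trained vectors in this check

model = GRU(ModelParam())
x = torch.randint(0, ModelParam.vocab_size, (32, 50))  # [batch_size, max_seq_len]
print(model(x).shape)  # expected: torch.Size([32, 5])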

I'm using nn.CrossEntropyLoss() as the loss function and optim.SGD as the optimizer. They are defined as follows.

# Loss function and optimizer.
from torch.optim import SGD

loss_func = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=learning_rate, weight_decay=0.9)

And my training procedure is roughly as follows.

for batch in train_iter:

    optimizer.zero_grad()

    # The prediction of the model, and its corresponding loss.
    prediction = model(batch.text.type(torch.LongTensor).to(device), batch.label.to(device))
    loss = loss_func(prediction, batch.label.to(device))

    loss.backward()
    optimizer.step()

    # Record the batch loss. nn.CrossEntropyLoss already averages over the
    # batch by default, so there is no need to divide by batch_size again.
    epoch_losses.append(loss.item())
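
The validation loop is omitted above; a minimal sketch of how the reported accuracy and loss can be computed (assuming a valid_iter built the same way as train_iter) looks like this:

# Sketch of a validation pass (assumed; not from the original post).
model.eval()
correct, total, valid_losses = 0, 0, []
with torch.no_grad():
    for batch in valid_iter:
        text = batch.text.type(torch.LongTensor).to(device)
        labels = batch.label.to(device)
        logits = model(text)
        valid_losses.append(loss_func(logits, labels).item())
        # The predicted class is the arg-max over the output logits.
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
print(f'valid acc: [{correct / total:.3f}] ({correct} in {total}), '
      f'valid loss {sum(valid_losses) / len(valid_losses):.6f}')
model.train()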

While training this model, the validation accuracy and loss are reported like this.

Epoch 1/300 valid acc: [0.839] (16668 in 19873), time spent 631.497 sec. Validate loss 1.506138. Best validate epoch is 1.
Epoch 2/300 valid acc: [0.839] (16668 in 19873), time spent 627.631 sec. Validate loss 1.577007. Best validate epoch is 2.
Epoch 3/300 valid acc: [0.839] (16668 in 19873), time spent 631.427 sec. Validate loss 1.580756. Best validate epoch is 3.
Epoch 4/300 valid acc: [0.839] (16668 in 19873), time spent 605.352 sec. Validate loss 1.581306. Best validate epoch is 4.
Epoch 5/300 valid acc: [0.839] (16668 in 19873), time spent 388.487 sec. Validate loss 1.581431. Best validate epoch is 5.
Epoch 6/300 valid acc: [0.839] (16668 in 19873), time spent 360.344 sec. Validate loss 1.581464. Best validate epoch is 6.
Epoch 7/300 valid acc: [0.839] (16668 in 19873), time spent 624.345 sec. Validate loss 1.581473. Best validate epoch is 7.
Epoch 8/300 valid acc: [0.839] (16668 in 19873), time spent 622.059 sec. Validate loss 1.581477. Best validate epoch is 8.
Epoch 9/300 valid acc: [0.839] (16668 in 19873), time spent 651.425 sec. Validate loss 1.581478. Best validate epoch is 9.
Epoch 10/300 valid acc: [0.839] (16668 in 19873), time spent 697.475 sec. Validate loss 1.581478. Best validate epoch is 10.
...

It shows that the validation loss stops decreasing after epoch 9, and the validation accuracy has remained unchanged since the first epoch. (Note that in my dataset one label accounts for 83% of the samples, so it can be inferred that my model tends to predict every sequence as that same label; however, this also happens when I train on another dataset that is considerably less imbalanced.) Has anyone met this situation before? I'm wondering if I have made a mistake in designing the model or the training procedure. Thanks for your help.
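
To check whether the model has really collapsed onto the majority class, one quick diagnostic (a sketch; valid_iter is the validation iterator assumed above) is to count how often each label is predicted:

# Count predictions per class (diagnostic sketch, not from the original post).
pred_counts = torch.zeros(5, dtype=torch.long)  # output_dim = 5
model.eval()
with torch.no_grad():
    for batch in valid_iter:
        logits = model(batch.text.type(torch.LongTensor).to(device))
        pred_counts += torch.bincount(logits.argmax(dim=1).cpu(), minlength=5)
print(pred_counts)  # one entry near 19873 means a single label is always predicted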

Update (Nov. 19th): I have added a figure showing how the loss behaves during training. It shows that both the training loss and the validation loss become constant after the 5th epoch.

[Figure: training and validation loss over 20 epochs]

I have now found that the loss fails to decrease mainly because the weight decay set in the optimizer is too high. Weight decay adds an L2 penalty on the weights (the SGD update becomes w <- w - lr * (grad + weight_decay * w)), and with weight_decay=0.9 the penalty term dominates the gradient and pushes all weights toward zero.

optimizer = SGD(model.parameters(), lr=learning_rate, weight_decay=0.9)

So I fixed this by changing the weight decay to 5e-5.

optimizer = SGD(model.parameters(), lr=learning_rate, weight_decay=5e-5)

This time the loss of my network begins to decrease. However, there is still no improvement in accuracy.

Epoch 1/100 valid acc: [0.839] (16668 in 19873), time spent 398.154 sec. Validate loss 0.713456. Best validate epoch is 1.
Epoch 2/100 valid acc: [0.839] (16668 in 19873), time spent 572.057 sec. Validate loss 0.631721. Best validate epoch is 2.
Epoch 3/100 valid acc: [0.839] (16668 in 19873), time spent 580.867 sec. Validate loss 0.613186. Best validate epoch is 3.
Epoch 4/100 valid acc: [0.839] (16668 in 19873), time spent 561.953 sec. Validate loss 0.601883. Best validate epoch is 4.
Epoch 5/100 valid acc: [0.839] (16668 in 19873), time spent 564.913 sec. Validate loss 0.596573. Best validate epoch is 5.
Epoch 6/100 valid acc: [0.839] (16668 in 19873), time spent 574.525 sec. Validate loss 0.592848. Best validate epoch is 6.
Epoch 7/100 valid acc: [0.839] (16668 in 19873), time spent 580.885 sec. Validate loss 0.591074. Best validate epoch is 7.
Epoch 8/100 valid acc: [0.839] (16668 in 19873), time spent 455.228 sec. Validate loss 0.589787. Best validate epoch is 8.
Epoch 9/100 valid acc: [0.839] (16668 in 19873), time spent 582.756 sec. Validate loss 0.588691. Best validate epoch is 9.
Epoch 10/100 valid acc: [0.839] (16668 in 19873), time spent 583.997 sec. Validate loss 0.588260. Best validate epoch is 10.
Epoch 11/100 valid acc: [0.839] (16668 in 19873), time spent 599.630 sec. Validate loss 0.588224. Best validate epoch is 11.
Epoch 12/100 valid acc: [0.839] (16668 in 19873), time spent 597.713 sec. Validate loss 0.586977. Best validate epoch is 12.
Epoch 13/100 valid acc: [0.839] (16668 in 19873), time spent 605.038 sec. Validate loss 0.587937. Best validate epoch is 13.
Epoch 14/100 valid acc: [0.839] (16668 in 19873), time spent 598.712 sec. Validate loss 0.587059. Best validate epoch is 14.
Epoch 15/100 valid acc: [0.839] (16668 in 19873), time spent 409.344 sec. Validate loss 0.587293. Best validate epoch is 15.
...

I'm wondering whether a learning rate of 1e-3 and a weight decay of 5e-5 are reasonable settings. My batch size is 32.
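
Both values are in commonly used ranges, so one adjustment that may help with the 83% majority class is to weight the loss by inverse class frequency, so that the majority class no longer dominates the gradient. A sketch (the counts below are placeholders for the real training-set label frequencies):

# Class-weighted loss (sketch; replace the counts with real label frequencies).
class_counts = torch.tensor([83000., 5000., 5000., 4000., 3000.])
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * len(class_counts)
loss_func = nn.CrossEntropyLoss(weight=class_weights.to(device))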
