简体   繁体   English

BERT 总是预测同一个类(微调)

[英]BERT always predicts same class (Fine-Tuning)

I am fine-tuning BERT on a financial news dataset.我正在金融新闻数据集上微调 BERT。 Unfortunately BERT seems to be trapped in a local minimum.不幸的是,BERT 似乎陷入了局部最小值。 It is content with learning to always predict the same class.它满足于学习总是预测同一个班级。

  • balancing the dataset didnt work平衡数据集不起作用
  • tuning parameters didnt work as well调整参数也不起作用

I am honestly not sure what is causing this problem.老实说,我不确定是什么导致了这个问题。 With the simpletransformers library I am getting very good results.使用 simpletransformers 库,我得到了非常好的结果。 I would really appreciate if somebody could help me.如果有人可以帮助我,我将不胜感激。 thanks a lot!多谢!

Full code on github: https://github.com/Bene939/BERT_News_Sentiment_Classifier github上的完整代码: https ://github.com/Bene939/BERT_News_Sentiment_Classifier

Code:代码:

from transformers import BertForSequenceClassification, AdamW, BertTokenizer, get_linear_schedule_with_warmup, Trainer, TrainingArguments
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import pandas as pd
from pathlib import Path
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from torch.nn import functional as F
from collections import defaultdict
import random


#defining tokenizer, model and optimizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=3)


if torch.cuda.is_available():
  print("\nUsing: ", torch.cuda.get_device_name(0))
  device = torch.device('cuda')
else:
  print("\nUsing: CPU")
  device = torch.device('cpu')
model = model.to(device)


#loading dataset
labeled_dataset = "news_headlines_sentiment.csv"
labeled_dataset_file = Path(labeled_dataset)
file_loaded = False
while not file_loaded:
  if labeled_dataset_file.exists():
    labeled_dataset = pd.read_csv(labeled_dataset_file)
    file_loaded = True
    print("Dataset Loaded")
  else:
    print("File not Found")
print(labeled_dataset)

#counting sentiments
negative = 0
neutral = 0
positive = 0
for idx, row in labeled_dataset.iterrows():
  if row["sentiment"] == 0:
    negative += 1
  elif row["sentiment"] == 1:
    neutral += 1
  else:
    positive += 1
print("Unbalanced Dataset")
print("negative: ", negative)
print("neutral: ", neutral)
print("positive: ", positive)

#balancing dataset to 1/3 per sentiment
for idx, row in labeled_dataset.iterrows():
  if row["sentiment"] == 0:
    if negative - neutral != 0:
      index_name = labeled_dataset[labeled_dataset["news"] == row["news"]].index
      labeled_dataset.drop(index_name, inplace=True)
      negative -= 1
  elif row["sentiment"] == 2:
    if positive - neutral != 0:
      index_name = labeled_dataset[labeled_dataset["news"] == row["news"]].index
      labeled_dataset.drop(index_name, inplace=True)
      positive -= 1

#custom dataset class
class NewsSentimentDataset(torch.utils.data.Dataset):
  def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

  def __getitem__(self, idx):
      item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
      item['labels'] = torch.tensor(self.labels[idx])
      return item

  def __len__(self):
      return len(self.labels)

#method for tokenizing dataset list
def tokenize_headlines(headlines, labels, tokenizer):

  encodings = tokenizer.batch_encode_plus(
      headlines,
      add_special_tokens = True,
      truncation = True,
      padding = 'max_length',
      return_attention_mask = True,
      return_token_type_ids = True
  )

  dataset = NewsSentimentDataset(encodings, labels)
  return dataset

#splitting dataset into training and validation set
#load news sentiment dataset
all_headlines = labeled_dataset['news'].tolist()
all_labels = labeled_dataset['sentiment'].tolist()

train_headlines, val_headlines, train_labels, val_labels = train_test_split(all_headlines, all_labels, test_size=.2)

val_dataset = tokenize_headlines(val_headlines, val_labels, tokenizer)
train_dataset = tokenize_headlines(train_headlines, val_labels, tokenizer)

#data loader
train_batch_size = 8
val_batch_size = 8

train_data_loader = DataLoader(train_dataset, batch_size = train_batch_size, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size = val_batch_size, sampler=SequentialSampler(val_dataset))

#optimizer and scheduler
num_epochs = 1
num_steps = len(train_data_loader) * num_epochs
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_steps*0.06, num_training_steps=num_steps)

#training and evaluation
seed_val = 64

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

for epoch in range(num_epochs):

  print("\n###################################################")
  print("Epoch: {}/{}".format(epoch+1, num_epochs))
  print("###################################################\n")

  #training phase
 
  average_train_loss = 0
  average_train_acc = 0
  model.train() 
  for step, batch in enumerate(train_data_loader):
      
      
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      labels = batch['labels'].to(device)
      token_type_ids = batch['token_type_ids'].to(device)


      outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids = token_type_ids)

      loss = F.cross_entropy(outputs[0], labels)
      average_train_loss += loss

      if step % 40 == 0:
        print("Training Loss: ", loss)

      logits = outputs[0].detach().cpu().numpy()
      label_ids = labels.to('cpu').numpy()

      average_train_acc += sklearn.metrics.accuracy_score(label_ids, np.argmax(logits, axis=1))
      print("predictions: ",np.argmax(logits, axis=1))
      print("labels:      ",label_ids)
      print("#############")
      optimizer.zero_grad()
      loss.backward()
      #maximum gradient clipping
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      
      optimizer.step()
      scheduler.step()
      model.zero_grad()

  average_train_loss = average_train_loss / len(train_data_loader)
  average_train_acc = average_train_acc / len(train_data_loader)
  print("======Average Training Loss: {:.5f}======".format(average_train_loss))
  print("======Average Training Accuracy: {:.2f}%======".format(average_train_acc*100))

  #validation phase
  average_val_loss = 0
  average_val_acc = 0
  model.eval()
  for step,batch in enumerate(val_data_loader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)

    pred = []
    with torch.no_grad():
      

      outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

      loss = F.cross_entropy(outputs[0], labels)
      average_val_loss += loss

      logits = outputs[0].detach().cpu().numpy()
      label_ids = labels.to('cpu').numpy()
      print("predictions: ",np.argmax(logits, axis=1))
      print("labels:      ",label_ids)
      print("#############")

      average_val_acc += sklearn.metrics.accuracy_score(label_ids, np.argmax(logits, axis=1))

  average_val_loss = average_val_loss / len(val_data_loader)
  average_val_acc = average_val_acc / len(val_data_loader)

  print("======Average Validation Loss: {:.5f}======".format(average_val_loss))
  print("======Average Validation Accuracy: {:.2f}%======".format(average_val_acc*100))
###################################################
Epoch: 1/1
###################################################

Training Loss:  tensor(1.1006, device='cuda:0', grad_fn=<NllLossBackward>)
predictions:  [1 0 2 0 0 0 2 0]
labels:       [2 0 1 1 0 1 0 1]
#############
predictions:  [2 2 0 0 0 2 0 0]
labels:       [1 2 1 0 2 0 1 2]
#############
predictions:  [0 0 0 0 1 0 0 1]
labels:       [0 1 1 0 1 1 2 0]
#############
predictions:  [0 0 0 2 0 1 0 0]
labels:       [0 0 0 2 0 0 2 1]
#############
predictions:  [1 0 0 0 0 0 2 0]
labels:       [0 2 2 1 0 0 0 0]
#############
predictions:  [0 0 0 0 0 1 0 0]
labels:       [1 0 2 2 2 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 2 2 2 0 2 0]
#############
predictions:  [0 1 0 0 0 0 0 0]
labels:       [2 2 0 2 0 0 0 1]
#############
predictions:  [0 0 0 0 0 2 0 1]
labels:       [0 1 0 2 2 0 1 2]
#############
predictions:  [0 0 2 0 0 0 1 0]
labels:       [0 0 0 1 2 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 1 0 1 0 1 1]
#############
predictions:  [0 2 0 0 0 0 0 0]
labels:       [2 2 0 1 0 1 2 1]
#############
predictions:  [0 1 0 0 0 0 1 2]
labels:       [2 2 1 0 2 0 0 2]
#############
predictions:  [0 0 1 1 1 1 0 1]
labels:       [1 2 1 1 1 1 2 2]
#############
predictions:  [1 0 0 0 0 1 2 1]
labels:       [1 0 1 1 0 0 0 2]
#############
predictions:  [0 1 1 1 1 0 2 1]
labels:       [2 2 1 2 2 1 1 2]
#############
predictions:  [0 0 1 0 1 1 0 0]
labels:       [1 0 0 1 0 1 0 2]
#############
predictions:  [1 2 0 0 1 2 0 0]
labels:       [0 2 2 1 2 0 1 0]
#############
predictions:  [0 2 1 1 0 1 1 0]
labels:       [2 2 0 1 1 0 1 2]
#############
predictions:  [1 0 1 1 1 1 1 0]
labels:       [0 2 0 1 0 1 2 2]
#############
predictions:  [0 2 1 2 0 0 1 1]
labels:       [2 1 1 1 1 2 2 0]
#############
predictions:  [0 1 2 2 2 1 1 2]
labels:       [2 2 1 1 2 1 0 1]
#############
predictions:  [2 2 2 1 2 1 1 1]
labels:       [0 1 1 0 0 2 2 1]
#############
predictions:  [1 2 2 2 1 2 1 2]
labels:       [0 0 0 0 2 0 1 2]
#############
predictions:  [2 1 1 1 2 2 2 2]
labels:       [1 0 2 2 1 0 0 0]
#############
predictions:  [2 1 2 2 2 1 2 2]
labels:       [2 1 1 1 1 1 2 2]
#############
predictions:  [1 1 0 2 1 2 1 2]
labels:       [2 2 0 2 0 1 2 0]
#############
predictions:  [0 1 1 2 0 1 2 1]
labels:       [2 2 2 1 2 2 0 1]
#############
predictions:  [2 1 1 1 1 2 1 1]
labels:       [0 1 1 2 1 0 0 2]
#############
predictions:  [1 2 2 0 1 1 1 2]
labels:       [0 1 2 1 2 1 0 1]
#############
predictions:  [0 1 1 1 1 1 1 0]
labels:       [0 2 0 1 1 2 2 2]
#############
predictions:  [1 2 1 1 2 1 1 0]
labels:       [0 2 2 2 0 0 1 0]
#############
predictions:  [2 2 2 1 2 1 1 2]
labels:       [2 2 1 2 1 0 0 0]
#############
predictions:  [2 2 1 2 2 2 1 2]
labels:       [1 1 2 2 2 0 2 1]
#############
predictions:  [2 2 2 2 2 0 2 2]
labels:       [2 2 1 2 0 1 1 2]
#############
predictions:  [1 1 2 1 2 2 0 1]
labels:       [2 1 1 1 0 0 2 2]
#############
predictions:  [2 1 2 2 2 2 1 0]
labels:       [0 2 0 2 0 0 0 0]
#############
predictions:  [2 2 2 2 2 2 2 2]
labels:       [1 1 0 2 0 1 2 1]
#############
predictions:  [2 2 2 2 1 2 2 2]
labels:       [1 0 0 1 1 0 0 0]
#############
predictions:  [2 2 2 1 2 2 2 2]
labels:       [1 0 1 1 0 2 2 0]
#############
Training Loss:  tensor(1.1104, device='cuda:0', grad_fn=<NllLossBackward>)
predictions:  [2 0 1 2 1 2 2 0]
labels:       [2 2 0 0 1 0 0 2]
#############
predictions:  [0 2 2 0 2 1 1 1]
labels:       [0 0 0 1 0 0 1 0]
#############
predictions:  [0 2 2 0 1 1 1 2]
labels:       [2 1 1 1 2 2 1 0]
#############
predictions:  [2 1 1 2 2 0 2 0]
labels:       [1 2 1 2 1 0 2 1]
#############
predictions:  [0 2 2 0 0 2 1 2]
labels:       [0 0 2 2 0 0 2 0]
#############
predictions:  [0 0 1 2 2 0 2 2]
labels:       [0 0 0 0 0 0 0 0]
#############
predictions:  [1 1 2 1 2 0 1 2]
labels:       [0 0 2 0 0 0 1 1]
#############
predictions:  [0 0 2 1 0 2 0 1]
labels:       [1 1 2 1 1 0 2 0]
#############
predictions:  [0 0 0 0 1 0 0 0]
labels:       [2 2 1 1 2 1 1 1]
#############
predictions:  [0 0 0 0 1 0 0 0]
labels:       [1 1 2 2 1 1 2 0]
#############
predictions:  [0 0 0 0 0 1 1 1]
labels:       [2 0 1 1 0 1 2 2]
#############
predictions:  [0 0 1 0 0 1 2 1]
labels:       [1 2 0 2 2 0 2 1]
#############
predictions:  [1 1 1 1 0 1 0 1]
labels:       [2 0 1 0 1 0 1 2]
#############
predictions:  [1 2 2 0 0 0 1 1]
labels:       [2 0 0 2 1 2 2 2]
#############
predictions:  [1 0 2 1 0 2 2 0]
labels:       [0 0 2 1 2 1 1 1]
#############
predictions:  [0 0 0 1 1 1 1 1]
labels:       [1 2 1 0 0 0 1 0]
#############
predictions:  [1 1 1 0 1 1 0 1]
labels:       [0 2 1 2 1 2 2 0]
#############
predictions:  [2 1 0 1 1 2 0 0]
labels:       [0 1 0 0 1 2 0 2]
#############
predictions:  [0 1 1 0 0 1 0 1]
labels:       [1 0 0 2 2 1 1 2]
#############
predictions:  [1 1 1 1 1 1 1 1]
labels:       [2 0 1 0 2 0 0 2]
#############
predictions:  [1 0 0 1 0 1 0 2]
labels:       [1 0 0 1 1 2 2 1]
#############
predictions:  [1 1 1 1 1 1 0 0]
labels:       [1 1 0 2 1 0 2 0]
#############
predictions:  [1 1 2 1 0 1 0 0]
labels:       [0 2 1 2 1 1 0 2]
#############
predictions:  [1 1 0 0 1 2 1 1]
labels:       [0 2 1 0 2 2 0 1]
#############
predictions:  [0 1 1 0 0 1 0 1]
labels:       [0 0 1 2 2 0 1 2]
#############
predictions:  [1 0 2 2 2 1 1 0]
labels:       [2 2 1 0 0 1 1 2]
#############
predictions:  [1 2 2 1 1 2 1 1]
labels:       [1 0 0 1 0 0 0 0]
#############
predictions:  [0 2 0 2 2 0 2 2]
labels:       [2 0 0 0 2 1 1 2]
#############
predictions:  [0 0 1 0 1 0 2 2]
labels:       [0 0 1 0 1 0 2 0]
#############
predictions:  [0 2 0 1 1 2 2 0]
labels:       [0 2 0 2 0 2 0 0]
#############
predictions:  [2 2 2 2 2 2 2 1]
labels:       [2 2 1 1 0 0 2 2]
#############
predictions:  [2 0 0 2 2 1 1 0]
labels:       [1 0 0 1 0 2 1 2]
#############
predictions:  [2 0 0 2 0 2 2 0]
labels:       [2 2 2 2 0 1 1 1]
#############
predictions:  [0 2 2 0 2 2 0 0]
labels:       [1 0 1 2 0 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 2]
labels:       [2 1 1 0 0 0 1 2]
#############
predictions:  [2 0 2 0 2 1 0 2]
labels:       [2 1 1 2 1 1 0 0]
#############
predictions:  [1 1 2 0 2 0 2 2]
labels:       [0 2 1 2 1 2 1 0]
#############
predictions:  [2 0 1 1 0 2 0 0]
labels:       [2 1 0 1 1 0 2 0]
#############
predictions:  [2 0 0 2 0 2 1 0]
labels:       [0 0 0 0 2 1 0 1]
#############
predictions:  [1 2 1 0 0 2 0 2]
labels:       [2 0 2 1 0 0 1 1]
#############
Training Loss:  tensor(1.1162, device='cuda:0', grad_fn=<NllLossBackward>)
predictions:  [2 0 0 1 1 1 0 1]
labels:       [0 1 1 1 1 2 2 1]
#############
predictions:  [0 2 0 1 2 0 0 1]
labels:       [2 2 1 0 1 0 0 0]
#############
predictions:  [0 0 1 0 0 0 0 1]
labels:       [1 0 2 0 0 2 2 0]
#############
predictions:  [2 1 2 2 0 1 2 0]
labels:       [2 0 1 0 2 1 0 1]
#############
predictions:  [1 0 0 2 0 0 1 1]
labels:       [2 2 0 2 0 2 0 0]
#############
predictions:  [0 0 1 0 0 0 0 0]
labels:       [2 2 2 1 2 2 2 2]
#############
predictions:  [0 0 1 1 0 1 1 0]
labels:       [2 1 1 1 0 2 1 0]
#############
predictions:  [0 0 0 1 0 0 1 0]
labels:       [2 0 2 2 0 0 1 2]
#############
predictions:  [1 0 1 0 0 2 0 0]
labels:       [1 1 2 0 0 1 0 0]
#############
predictions:  [2 1 0 0 0 1 0 0]
labels:       [1 2 0 0 0 0 0 0]
#############
predictions:  [0 2 0 0 0 0 0 0]
labels:       [2 0 1 1 2 2 1 1]
#############
predictions:  [0 1 0 0 0 1 0 2]
labels:       [0 2 1 1 0 0 1 2]
#############
predictions:  [0 2 1 0 0 1 1 1]
labels:       [1 1 0 2 0 1 1 0]
#############
predictions:  [0 1 1 0 0 0 1 0]
labels:       [0 0 1 0 1 2 1 1]
#############
predictions:  [0 1 1 0 1 0 0 0]
labels:       [0 1 1 1 2 2 2 0]
#############
predictions:  [0 0 0 0 1 1 0 0]
labels:       [2 0 2 2 1 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 0 2 2 0 1 1]
#############
predictions:  [0 1 0 0 0 0 0 0]
labels:       [0 2 0 1 1 2 0 2]
#############
predictions:  [1 1 0 1 0 1 0 2]
labels:       [1 2 0 0 2 2 2 1]
#############
predictions:  [1 1 0 0 0 1 2 1]
labels:       [0 0 1 2 2 1 2 2]
#############
predictions:  [1 1 1 0 1 1 2 0]
labels:       [0 0 0 2 0 1 0 2]
#############
predictions:  [0 1 0 0 1 1 2 1]
labels:       [2 0 0 1 2 2 1 2]
#############
predictions:  [1 0 0 0 1 0 0 1]
labels:       [1 2 2 2 2 1 0 1]
#############
predictions:  [2 0 0 0 0 0 0 0]
labels:       [1 2 0 2 2 1 1 1]
#############
predictions:  [2 0 1 1 0 0 1 0]
labels:       [0 0 0 0 2 2 1 1]
#############
predictions:  [2 0 0 1 0 0 1 1]
labels:       [2 2 1 1 0 0 1 0]
#############
predictions:  [1 1 1 1 1 2 0 0]
labels:       [0 0 2 1 0 0 0 0]
#############
predictions:  [1 1 2 0 1 2 0 1]
labels:       [0 2 1 0 2 0 0 1]
#############
predictions:  [0 0 2 1 0 2 0 1]
labels:       [1 2 0 2 2 1 0 0]
#############
predictions:  [0 0 2 0 2 1 1 2]
labels:       [2 2 1 2 2 2 0 0]
#############
predictions:  [0 1 0 0 0 0 2 1]
labels:       [1 1 0 1 1 1 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 0 0 2 0 0 2]
#############
predictions:  [2 2 2 0 1 1 1 0]
labels:       [1 0 2 1 1 2 0 0]
#############
predictions:  [0 0 1 0 0 0 2 0]
labels:       [0 1 2 1 1 0 0 0]
#############
predictions:  [0 2 0 1 0 2 0 0]
labels:       [0 0 2 1 1 0 2 2]
#############
predictions:  [0 0 1 2 0 2 0 1]
labels:       [2 2 0 0 0 2 2 2]
#############
predictions:  [1 0 0 0 2 0 0 1]
labels:       [2 0 1 1 1 0 0 1]
#############
predictions:  [0 1 0 0 0 0 0 2]
labels:       [1 1 1 0 0 0 2 2]
#############
predictions:  [0 2 0 1 0 2 0 0]
labels:       [1 1 1 1 2 2 1 0]
#############
predictions:  [1 2 0 0 0 0 0 0]
labels:       [2 0 2 1 0 1 1 1]
#############
Training Loss:  tensor(1.2082, device='cuda:0', grad_fn=<NllLossBackward>)
predictions:  [0 2 0 0 0 0 2 0]
labels:       [1 0 2 1 2 2 1 1]
#############
predictions:  [2 0 0 0 0 0 1 0]
labels:       [1 0 0 0 0 2 1 0]
#############
predictions:  [0 0 0 0 2 1 1 1]
labels:       [0 2 2 0 1 2 1 1]
#############
predictions:  [2 1 0 1 0 0 2 0]
labels:       [1 0 2 1 0 2 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 0 0 0 0 1 0]
#############
predictions:  [0 2 1 0 0 0 1 1]
labels:       [0 2 2 2 2 1 1 0]
#############
predictions:  [0 0 0 1 1 0 0 1]
labels:       [0 1 0 1 2 2 2 2]
#############
predictions:  [0 0 0 1 1 1 1 2]
labels:       [2 2 1 2 0 1 1 1]
#############
predictions:  [0 1 2 0 0 1 0 0]
labels:       [0 2 1 0 0 1 0 0]
#############
predictions:  [1 1 1 1 0 0 0 0]
labels:       [2 1 2 1 0 2 2 1]
#############
predictions:  [0 1 2 0 0 1 1 0]
labels:       [2 0 2 1 1 1 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 0 1 1 0 0]
#############
predictions:  [0 0 0 0 0 1 2 2]
labels:       [2 2 1 1 0 2 1 2]
#############
predictions:  [0 1 0 0 1 1 0 1]
labels:       [0 1 0 2 1 0 0 1]
#############
predictions:  [0 2 2 0 0 0 0 2]
labels:       [0 0 2 1 2 2 0 1]
#############
predictions:  [2 0 0 2 2 0 2 0]
labels:       [2 1 0 2 2 0 1 0]
#############
predictions:  [0 2 2 0 2 1 1 2]
labels:       [1 1 0 0 2 1 0 0]
#############
predictions:  [1 1 2 2 0 0 1 2]
labels:       [2 0 2 0 1 1 1 1]
#############
predictions:  [0 1 1 0 0 1 1 0]
labels:       [0 2 1 0 0 2 2 0]
#############
predictions:  [2 1 0 0 0 0 1 1]
labels:       [0 2 0 2 0 0 1 1]
#############
predictions:  [1 2 0 1 2 0 0 0]
labels:       [1 0 1 1 0 2 2 2]
#############
predictions:  [0 0 0 0 2 2 1 2]
labels:       [2 2 2 1 1 1 1 0]
#############
predictions:  [1 2 0 1 0 0 2 0]
labels:       [2 2 1 1 1 0 2 0]
#############
predictions:  [2 0 0 0 0 2 1]
labels:       [0 1 1 2 2 0 2]
#############
======Average Training Loss: 1.11279======
======Average Training Accuracy: 33.77%======
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 1 1 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 0 2 1 0 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 2 2 1 2 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 1 2 0 1 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 2 0 0 1 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 1 2 1 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 1 2 0 2 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 1 2 2 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 2 2 0 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 0 2 0 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 1 1 2 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 1 2 2 0 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 0 0 1 2 2 1]
#############
predictions:  [0 0 0 1 0 0 0 0]
labels:       [0 0 1 1 0 2 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 2 2 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 1 2 2 2 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 2 1 2 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 0 0 2 2 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 0 1 0 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 2 2 2 2 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 2 1 1 0 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 1 1 2 0 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 2 1 2 2 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 1 0 2 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 2 1 1 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 0 1 2 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 1 1 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 1 0 0 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 0 0 0 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 1 1 2 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 1 2 1 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 2 0 1 1 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 1 0 1 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 2 2 1 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 2 0 2 0 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 1 1 1 0 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 1 2 2 0 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 2 0 0 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 0 1 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 2 1 1 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 2 2 2 2 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 2 2 1 0 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 2 2 2 1 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 0 0 1 0 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 0 0 0 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 1 2 0 2 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 2 0 1 2 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 2 0 0 0 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 1 0 0 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 0 1 1 2 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 0 0 2 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 2 1 1 1 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 0 0 2 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 2 1 0 2 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 1 2 2 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 0 0 2 1 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 2 0 2 1 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 2 0 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 0 0 1 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 2 2 0 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 1 1 1 0 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 1 2 2 1 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 0 2 0 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 1 1 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 1 1 1 1 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 2 1 0 0 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 2 1 0 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 2 2 0 0 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 2 2 0 0 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 0 2 2 2 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 0 1 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 1 2 0 1 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 0 0 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 1 2 0 2 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 1 0 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 0 1 0 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 0 0 2 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 1 1 2 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 2 2 0 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 2 0 1 1 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 0 0 1 2 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 2 1 2 0 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 1 1 1 0 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 1 2 0 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 1 0 1 1 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 0 2 1 0 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 0 0 0 2 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 0 1 2 2 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 2 0 1 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 2 1 0 2 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 1 2 0 2 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 2 2 2 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 1 2 0 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 1 1 0 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 2 2 2 2 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 0 0 1 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 2 1 2 1 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 0 0 0 2 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 1 1 1 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 1 0 2 2 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 1 1 1 2 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 2 0 1 0 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 2 2 0 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 1 2 2 2 1 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 0 1 0 2 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 2 1 0 2 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 2 0 2 2 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 2 0 0 1 0 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 1 0 0 0 2 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 2 0 1 2 1 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 2 2 2 2 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 0 1 2 0 2 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 2 1 1 1 1 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 0 0 0 1 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 1 2 0 1 2 2 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 1 1 1 2 1 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [1 0 1 1 1 0 0 2]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 2 0 0 0 0 1 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [0 0 1 1 2 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 1 1 1 0 1 0 0]
#############
predictions:  [0 0 0 0 0 0 0 0]
labels:       [2 0 2 2 2 0 0 1]
#############
predictions:  [0 0 0 0 0 0 0]
labels:       [2 2 1 1 0 0 1]
#############
======Average Validation Loss: 1.09527======
======Average Validation Accuracy: 35.53%======

I want to leave an answer here for people that are struggling with a similar issue.我想在这里为那些在类似问题上苦苦挣扎的人留下一个答案。

  1. Try different learning rates.尝试不同的学习率。 The learning rate is most likely too high, try to lower it.学习率很可能太高,尝试降低它。 (Worked for me) (为我工作)
  2. Reduce the number of Epochs while training.在训练时减少 Epoch 的数量。 This problem is related to fine-tuning.这个问题与微调有关。 If the dataset is extensive, 1-2 epochs may be enough.如果数据集很广泛,1-2 个 epoch 可能就足够了。
  3. Try out different batch sizes or introduce dropout layers.尝试不同的批量大小或引入 dropout 层。

For a lengthy discussion see: https://github.com/ThilinaRajapakse/simpletransformers/issues/234有关冗长的讨论,请参阅: https ://github.com/ThilinaRajapakse/simpletransformers/issues/234

For multi-class classification/sentiment analysis using BERT the 'neutral' class HAS TO BE 2!!对于使用 BERT 的多类分类/情感分析,“中性”类必须是 2! It CANNOT be between 'negative' = 0 and 'positive' = 2它不能介于 'negative' = 0 和 'positive' = 2 之间

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM