
BertForSequenceClassification vs. BertForMultipleChoice for sentence multi-class classification

I'm working on a text classification problem (e.g., sentiment analysis), where I need to classify a text string into one of five classes.

I just started using the Hugging Face Transformers package and BERT with PyTorch. What I need is a classifier with a softmax layer on top so that I can do 5-way classification. Confusingly, there seem to be two relevant options in the Transformers package: BertForSequenceClassification and BertForMultipleChoice.

Which one should I use for my 5-way classification task? What are the appropriate use cases for them?

The documentation for BertForSequenceClassification doesn't mention softmax at all, although it does mention cross-entropy. I am not sure whether this class is only for 2-class classification (i.e., logistic regression).

Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output), e.g. for GLUE tasks.

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If config.num_labels > 1 a classification loss is computed (Cross-Entropy).
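To make the num_labels rule in the quoted documentation concrete, here is a minimal sketch assuming PyTorch and transformers 4.x (where the model returns an output object with .loss and .logits). It builds a tiny, randomly initialised BERT from a BertConfig so it runs without downloading weights; the layer sizes and the helper name tiny_bert_classifier are illustrative assumptions. With num_labels=5 the reported loss equals torch.nn.functional.cross_entropy on the logits, so the class is not limited to two classes; with num_labels=1 it equals a mean-squared-error loss.

```python
import torch
import torch.nn.functional as F
from transformers import BertConfig, BertForSequenceClassification


def tiny_bert_classifier(num_labels):
    # Hypothetical helper: a tiny, randomly initialised BERT so the sketch runs offline.
    # These sizes are arbitrary assumptions, not recommended settings.
    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=num_labels,
    )
    return BertForSequenceClassification(config).eval()  # eval() disables dropout


torch.manual_seed(0)
input_ids = torch.randint(0, 100, (4, 10))  # (batch_size, seq_len)

# num_labels > 1: integer class indices in [0, num_labels - 1] -> cross-entropy loss
clf = tiny_bert_classifier(num_labels=5)
class_labels = torch.tensor([0, 3, 4, 1])
with torch.no_grad():
    clf_out = clf(input_ids=input_ids, labels=class_labels)
ce_loss = F.cross_entropy(clf_out.logits, class_labels)  # same value as clf_out.loss

# num_labels == 1: float targets -> mean-squared-error (regression) loss
reg = tiny_bert_classifier(num_labels=1)
targets = torch.tensor([0.5, 1.0, -0.2, 2.0])
with torch.no_grad():
    reg_out = reg(input_ids=input_ids, labels=targets)
mse_loss = F.mse_loss(reg_out.logits.squeeze(-1), targets)  # same value as reg_out.loss

print(clf_out.logits.shape)  # torch.Size([4, 5])
print(reg_out.logits.shape)  # torch.Size([4, 1])
```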

The documentation for BertForMultipleChoice mentions softmax, but the way the labels are described, it sounds like this class is for multi-label classification (that is, a binary classification for multiple labels).

Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax), e.g. for RocStories/SWAG tasks.

  • labels (torch.LongTensor of shape (batch_size,), optional, defaults to None) – Labels for computing the multiple choice classification loss. Indices should be in [0, ..., num_choices] where num_choices is the size of the second dimension of the input tensors.
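A minimal sketch of the input and label shapes BertForMultipleChoice expects, under the same assumptions (PyTorch, transformers 4.x, a tiny random-weight config with arbitrary sizes, and a hypothetical helper score_choices). Each sample bundles num_choices sequences, and its label is the index of the correct one, so in practice the label ranges from 0 to num_choices - 1. The same model accepts a different number of choices on each call.

```python
import torch
from transformers import BertConfig, BertForMultipleChoice

# Tiny random-weight config (arbitrary sizes, assumed for an offline sketch).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertForMultipleChoice(config).eval()


def score_choices(num_choices, batch_size=3, seq_len=12):
    # Each sample contributes num_choices sequences (e.g. context + candidate ending),
    # so the inputs have shape (batch_size, num_choices, seq_len).
    input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))
    # One label per sample: the index of the correct choice, 0 .. num_choices - 1.
    labels = torch.randint(0, num_choices, (batch_size,))
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=labels)
    return out.logits, out.loss


torch.manual_seed(0)
logits_2, loss_2 = score_choices(num_choices=2)
logits_4, loss_4 = score_choices(num_choices=4)  # same model, more choices
print(logits_2.shape)  # torch.Size([3, 2])
print(logits_4.shape)  # torch.Size([3, 4])
```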

Thank you for any help.

The answer to this lies in the (admittedly very brief) description of what the tasks are about:

[BertForMultipleChoice] [...], e.g. for RocStories/SWAG tasks.

When looking at the paper for SWAG, it seems that the task is actually learning to choose from varying options. This is in contrast to your "classical" classification task, in which the "choices" (i.e., classes) do not vary across your samples, which is exactly what BertForSequenceClassification is for.
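The difference between fixed classes and varying options is visible in the heads themselves. This sketch (same assumptions: PyTorch, transformers 4.x, tiny random config with arbitrary sizes) prints the classifier layer of each model, then reproduces the multiple-choice scoring by hand using the models' bert and classifier submodules. Dropout is skipped in the manual path because both models are in eval mode, where it is a no-op.

```python
import torch
from transformers import BertConfig, BertForMultipleChoice, BertForSequenceClassification

torch.manual_seed(0)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=5)
seq_model = BertForSequenceClassification(config).eval()
mc_model = BertForMultipleChoice(config).eval()

# Sequence classification: one output unit per class, fixed when the model is built.
print(seq_model.classifier)  # Linear(in_features=32, out_features=5, bias=True)
# Multiple choice: a single output unit, i.e. one score per (context, choice) sequence.
print(mc_model.classifier)   # Linear(in_features=32, out_features=1, bias=True)

# Reproduce multiple-choice scoring by hand: flatten the choices into the batch,
# score every sequence with the same head, then regroup the scores per sample.
batch_size, num_choices, seq_len = 2, 3, 8
input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))
with torch.no_grad():
    pooled = mc_model.bert(input_ids=input_ids.view(-1, seq_len)).pooler_output
    manual_logits = mc_model.classifier(pooled).view(batch_size, num_choices)
    model_logits = mc_model(input_ids=input_ids).logits
print(torch.allclose(manual_logits, model_logits, atol=1e-6))  # True
```

Because the multiple-choice head only ever produces one score per sequence, the softmax over choices is taken across however many candidates a sample has, which is why num_labels does not affect its size.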

Both variants can in fact handle an arbitrary number of classes (in the case of BertForSequenceClassification, set via num_labels in the config) or choices (in the case of BertForMultipleChoice, given by the second dimension of the input tensors). But since it seems like you are dealing with a case of "classical" classification, I suggest using the BertForSequenceClassification model.
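A minimal sketch of 5-way fine-tuning with BertForSequenceClassification, assuming PyTorch and transformers 4.x. To stay self-contained it trains a tiny random-weight model on a fake, already-tokenised batch for a few AdamW steps; the sizes, learning rate and step count are arbitrary assumptions. For real data you would start from pretrained weights, e.g. BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5), and tokenise your texts with the matching tokenizer (the checkpoint name is only an example).

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

torch.manual_seed(0)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, num_labels=5)
model = BertForSequenceClassification(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Fake batch: 8 "texts" already converted to token ids, each with a label in 0..4.
input_ids = torch.randint(0, config.vocab_size, (8, 16))
attention_mask = torch.ones_like(input_ids)  # no padding in this toy batch
labels = torch.randint(0, config.num_labels, (8,))

model.train()  # enable dropout for training
losses = []
for step in range(20):
    optimizer.zero_grad()
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    out.loss.backward()  # cross-entropy over the 5 classes, computed from raw logits
    optimizer.step()
    losses.append(out.loss.item())

model.eval()
with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
preds = logits.argmax(dim=-1)         # predicted class per text
probs = torch.softmax(logits, dim=-1)  # only needed if you want probabilities
```

Since argmax of the logits and argmax of the softmax probabilities pick the same class, the explicit softmax is only needed when you want probability estimates.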

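Briefly addressing the missing softmax in BertForSequenceClassification: since classification tasks can compute the loss across classes independently of the sample (unlike multiple choice, where your distribution is changing), this allows you to use cross-entropy loss (in PyTorch, torch.nn.CrossEntropyLoss, which combines a log-softmax with the negative log-likelihood loss). Folding the softmax into the loss computation this way gives increased numerical stability. As a consequence, the model returns raw logits; apply a softmax yourself if you need class probabilities.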
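To illustrate why the softmax is folded into the loss, here is a plain-Python sketch of the general log-sum-exp technique. It is not PyTorch's actual implementation, and the function names are hypothetical. Computing the softmax first and then taking the log overflows for large logits, while computing the log-softmax directly stays finite.

```python
import math


def softmax(logits):
    # Subtract the max before exponentiating so exp() never overflows.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def naive_cross_entropy(logits, target):
    # Softmax first, then log: math.exp() raises OverflowError for large logits.
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[target] / sum(exps))


def stable_cross_entropy(logits, target):
    # -log softmax(z)[t] = log(sum_i exp(z_i)) - z_t, with the log-sum-exp trick:
    # log(sum_i exp(z_i)) = m + log(sum_i exp(z_i - m)), where m = max(z).
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]
```

For example, with logits [1000.0, 0.0, 0.0, 0.0, 0.0], naive_cross_entropy raises OverflowError because math.exp(1000.0) overflows, whereas stable_cross_entropy(logits, 0) returns 0.0 and stable_cross_entropy(logits, 1) returns 1000.0. For moderate logits both functions agree.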
