

How to Use Class Weights with Focal Loss in PyTorch for an Imbalanced Dataset in Multiclass Classification

I am working on a multiclass classification problem (4 classes) for a language task, and I am using a BERT model for the classification. I am following this blog as a reference. My fine-tuned BERT model returns nn.LogSoftmax(dim=1).

My data is quite imbalanced, so I used sklearn.utils.class_weight.compute_class_weight to compute the class weights and passed them to the loss.

from sklearn.utils.class_weight import compute_class_weight

# one weight per class, inversely proportional to class frequency
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
weights = torch.tensor(class_weights, dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)  # NLLLoss pairs with the model's LogSoftmax output
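
For context, a minimal sketch of how this pairs with the model output (the model call and variable names here are illustrative, not taken from the blog):

log_probs = model(input_ids, attention_mask)  # (batch, 4) log-probabilities from nn.LogSoftmax(dim=1)
loss = cross_entropy(log_probs, labels)       # nn.NLLLoss expects log-probabilities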

My results were not great, so I thought of experimenting with focal loss, for which I have the following code.

class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits  # unused in this snippet
    self.reduce = reduce

  def forward(self, inputs, targets):
    # per-sample cross entropy (despite the BCE name, this is multi-class CE);
    # reduction='none' keeps one value per sample so the focal term acts per sample
    BCE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)

    pt = torch.exp(-BCE_loss)  # the model's probability for the true class
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss

I have three questions now. The first is the most important:

  1. Should I use class weights with focal loss?
  2. If I have to implement weights inside this focal loss, can I use the weight parameter of nn.CrossEntropyLoss() ?
  3. If this implementation is incorrect, what should the proper code be, including the weights (if possible)?

You may find the answers to your questions as follows:

  1. Focal loss handles class imbalance automatically, so separate weights are not required: the alpha and gamma factors take care of the class imbalance in the focal loss equation.
  2. No extra weights are needed, because focal loss handles the imbalance through its alpha and gamma modulating factors.
  3. The implementation you mentioned is correct according to the focal loss formula, but I had trouble getting my model to converge with that version, so I used the following implementation from the mmdetection framework:
    import torch.nn.functional as F
    from mmdet.models.losses.utils import weight_reduce_loss  # mmdetection helper

    def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0,
                              alpha=0.25, reduction='mean', avg_factor=None):
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # pt here is the probability of the incorrect outcome, i.e. (1 - p_t)
        # in the paper's notation, so pt.pow(gamma) is the focal modulating term
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) *
                        (1 - target)) * pt.pow(gamma)
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
        # weight_reduce_loss applies optional per-element weights, then reduces
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss
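
A hypothetical usage sketch for a 4-class problem (shapes and names are my own illustration): this sigmoid-based variant treats each class as a one-vs-rest binary problem, so the targets must be one-hot encoded to the same shape as the predictions.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 4)                         # batch of 8, 4 classes
labels = torch.randint(0, 4, (8,))
one_hot = F.one_hot(labels, num_classes=4).float() # same shape as logits
loss = py_sigmoid_focal_loss(logits, one_hot, gamma=2.0, alpha=0.25)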

You can also experiment with another available focal loss version.

I think the OP would have gotten their answer by now. I am writing this for other people who might ponder over this.

There is one problem in the OP's implementation of focal loss:

  1. F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

In this line, the same alpha value is multiplied with every class's output probability, i.e. pt. Additionally, the code doesn't show how we get pt. A very good implementation of focal loss can be found here, but that implementation is only for binary classification, as it holds alpha and 1-alpha for the two classes in the self.alpha tensor.

In the case of multi-class or multi-label classification, the self.alpha tensor should contain a number of elements equal to the total number of labels. The values could be the inverse label frequencies or the inverse normalized label frequencies (just be cautious with labels that have a frequency of 0). A minimal sketch of this idea is shown below.
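
Here is a minimal multi-class sketch along those lines (the class name and the per-class alpha handling are my own illustration, not code from the linked repository):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiClassFocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2.0, reduction='mean'):
        super().__init__()
        # alpha: 1-D tensor, one weight per class, e.g. inverse normalized
        # label frequencies (guard against zero-frequency labels)
        self.register_buffer('alpha', torch.as_tensor(alpha, dtype=torch.float))
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction='none')  # -log(pt), per sample
        pt = torch.exp(-ce)                # probability of the true class
        alpha_t = self.alpha[targets]      # weight of each sample's true class
        loss = alpha_t * (1 - pt) ** self.gamma * ce
        if self.reduction == 'mean':
            return loss.mean()
        return loss.sum() if self.reduction == 'sum' else loss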

I think the implementation in your question is wrong: the alpha is the class weight.

In cross entropy, the class weight is the alpha_t shown in the following expression:

CE(p_t) = -alpha_t * log(p_t)

You can see that it is alpha_t (indexed by the true class) rather than a single alpha.
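
As a quick check of this correspondence (a sketch with made-up shapes; PyTorch's weight argument implements exactly this alpha_t scaling):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 4)
targets = torch.randint(0, 4, (8,))
alpha = torch.tensor([0.1, 0.2, 0.3, 0.4])   # one weight per class

# weighted CE scales each sample's -log(p_t) by the weight of its true class
manual = alpha[targets] * F.cross_entropy(logits, targets, reduction='none')
built_in = F.cross_entropy(logits, targets, weight=alpha, reduction='none')
assert torch.allclose(manual, built_in)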

In focal loss, the formula is:

FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

and we can see from this popular PyTorch implementation that alpha acts in the same way as a class weight.
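
The relevant pattern, paraphrased from that style of implementation (not a verbatim copy; variable names are my own):

import torch
import torch.nn.functional as F

def focal_loss_step(logits, targets, alpha, gamma=2.0):
    # log(p_t): log-probability of each sample's true class
    logpt = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = logpt.exp()
    at = alpha.gather(0, targets)   # alpha_t, used exactly like a class weight
    return (-at * (1 - pt) ** gamma * logpt).mean()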

References:

  1. https://amaarora.github.io/2020/06/29/FocalLoss.html#alpha-and-gamma
  2. https://github.com/clcarwin/focal_loss_pytorch
