
Using Focal Loss for imbalanced dataset in PyTorch

I found this implementation of focal loss on GitHub, and I am using it for a binary classification problem with an imbalanced dataset.

# IMPLEMENTATION CREDIT: https://github.com/clcarwin/focal_loss_pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FocalLoss(nn.Module):
    def __init__(self, gamma=0.5, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # A scalar alpha is expanded to per-class weights [alpha, 1 - alpha]
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        # Flatten any spatial dimensions so the input becomes (N*..., C)
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input)     # implicit dim triggers the warning below; dim=1 is intended
        logpt = logpt.gather(1, target)  # log-probability of the true class
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
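
For reference, this class implements the focal loss of Lin et al., FL(pt) = -alpha_t * (1 - pt)^gamma * log(pt), where pt is the model's probability for the true class: logpt holds log(pt), the (1 - pt) ** gamma factor down-weights easy, well-classified examples, and alpha supplies the per-class weight alpha_t.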

gamma = args.gamma
alpha = args.alpha

criterion = FocalLoss(gamma, alpha)
m = nn.Sigmoid()

I use this criterion in the training phase as follows:

for i_batch, sample_batched in enumerate(dataloader_train):
    #pdb.set_trace()
    feats = torch.stack(sample_batched['image'])
    labels = torch.as_tensor(sample_batched['label']).cuda()
    print('feats shape: ', feats.shape)
    print('labels shape: ', labels.shape)
    output = model(feats)
    loss = criterion(m(output[:,1]-output[:,0]), labels.float())

The error is:

train: True test: False
preparing datasets and dataloaders......
creating models......

=>Epoches 1, learning rate = 0.0010000, previous best = 0.0000
training...
feats shape:  torch.Size([64, 419, 512])
labels shape:  torch.Size([64])
main_classifier.py:86: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  logpt = F.log_softmax(input)
Traceback (most recent call last):
  File "main_classifier.py", line 346, in <module>
    loss = criterion(m(output[:,1]-output[:,0]), labels.float())
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "main_classifier.py", line 87, in forward
    logpt = logpt.gather(1,target)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

How should I fix this error?

Is this implementation of FocalLoss correct?

Unlike BCEWithLogitsLoss, passing the same arguments you would use for CrossEntropyLoss solved the problem:

#loss = criterion(m(output[:,1]-output[:,0]), labels.float())  # old call: 1-D input, fails
loss = criterion(output, labels)  # raw [N, C] logits and integer labels, as for CrossEntropyLoss
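
The IndexError comes from the original call: m(output[:,1]-output[:,0]) is a 1-D tensor of shape [64], so logpt.gather(1, target) fails because a 1-D tensor has no dimension 1. Like CrossEntropyLoss, this FocalLoss expects raw 2-D logits of shape [N, C] and a LongTensor of class indices of shape [N]. A minimal sketch of the corrected call (tensor shapes taken from the log above; random data stands in for the real model output and labels):

import torch

# FocalLoss as defined above
criterion = FocalLoss(gamma=0.5, alpha=0.25)

output = torch.randn(64, 2)           # raw logits, shape [N, C]; no sigmoid/softmax applied
labels = torch.randint(0, 2, (64,))   # class indices as int64, shape [N]

loss = criterion(output, labels)      # gather(1, target) now sees a 2-D input
print(loss.item())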

Credit to Piotr from NVidia.
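
If you would rather keep the sigmoid-style binary formulation the question attempted, a common alternative (not from the original post; this mirrors torchvision's sigmoid_focal_loss) applies the focal weighting on top of binary_cross_entropy_with_logits:

import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    # logits: raw 1-D scores, e.g. output[:, 1] - output[:, 0]; targets: float {0, 1}
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # per-class weight
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

logits = torch.randn(64)
targets = torch.randint(0, 2, (64,)).float()
print(binary_focal_loss(logits, targets).item())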

