
multi-class weighted loss function in pytorch

I am training a unet based model for a multi-class segmentation task on the pytorch framework. I am optimizing the model with the following loss function:

from typing import List

import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss

# soft_jaccard_score is an external helper assumed to be in scope
# (e.g. from pytorch_toolbelt.losses.functional).


class MulticlassJaccardLoss(_Loss):
    """Implementation of Jaccard loss for multiclass (semantic) image segmentation task"""

    __name__ = 'mc_jaccard_loss'

    def __init__(self, classes: List[int] = None, from_logits=True, weight=None, reduction='elementwise_mean'):
        super(MulticlassJaccardLoss, self).__init__(reduction=reduction)
        self.classes = classes
        self.from_logits = from_logits
        self.weight = weight

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """
        :param y_pred: NxCxHxW
        :param y_true: NxHxW
        :return: scalar
        """
        if self.from_logits:
            y_pred = y_pred.softmax(dim=1)

        n_classes = y_pred.size(1)
        smooth = 1e-3

        if self.classes is None:
            classes = range(n_classes)
        else:
            classes = self.classes
            n_classes = len(classes)

        # Per-class loss buffer (requires_grad is False by default for torch.zeros)
        loss = torch.zeros(n_classes, dtype=torch.float, device=y_pred.device)

        if self.weight is None:
            weights = [1] * n_classes
        else:
            weights = self.weight

        for class_index, weight in zip(classes, weights):

            jaccard_target = (y_true == class_index).float()
            jaccard_output = y_pred[:, class_index, ...]

            # number of ground-truth pixels belonging to this class
            num_preds = jaccard_target.long().sum()

            if num_preds == 0:
                loss[class_index - 1] = 0  # custom
            else:
                iou = soft_jaccard_score(jaccard_output, jaccard_target, from_logits=False, smooth=smooth)
                loss[class_index - 1] = (1.0 - iou) * weight  # custom

        if self.reduction == 'elementwise_mean':
            return loss.mean()

        if self.reduction == 'sum':
            return loss.sum()

        return loss

I am calculating the loss for only two classes (class 1 and class 2, not the background):

MulticlassJaccardLoss(weight=[0.5,10], classes=[1,2], from_logits=False)
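
For context, here is a minimal sketch of how such a loss is typically plugged into a training step. The model, optimizer, and tensor shapes below are placeholders for illustration and are not taken from the question:

import torch
import torch.nn as nn

# Placeholder network standing in for the UNet (3 output channels = 3 classes).
model = nn.Conv2d(3, 3, kernel_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

criterion = MulticlassJaccardLoss(weight=[0.5, 10], classes=[1, 2], from_logits=False)

images = torch.randn(2, 3, 64, 64)        # NxCxHxW input batch
masks = torch.randint(0, 3, (2, 64, 64))  # NxHxW integer labels, 0 = background

optimizer.zero_grad()
probs = model(images).softmax(dim=1)      # from_logits=False, so pass probabilities
loss = criterion(probs, masks)
loss.backward()                           # the error from the question is raised at this step
optimizer.step()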

When I train the model, it trains for the first few iterations and then I get the following error:

element 0 of tensors does not require grad and does not have a grad_fn

What is the mistake in the code?

Thanks!

Try setting:

torch.zeros(..., requires_grad=True)

I believe requires_grad=False is the default for torch.zeros, so this may help here.
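
A quick standalone check of that claim (a sketch only, not the fix applied to the question's code):

import torch

t = torch.zeros(3)
print(t.requires_grad)   # False: gradients are not tracked by default

t = torch.zeros(3, requires_grad=True)
print(t.requires_grad)   # True: the buffer now participates in autograd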
