繁体   English   中英

如何在 Detectron2 中设置自定义类权重

[英]How set custom class weights in Detectron2

我将https://detectron2.readthedocs.io/tutorials/install.html用于我的数据集以及其他类和对象。

而且我的数据集是不平衡的。 我希望为每个班级设置不同的权重。 我怎样才能做到这一点?

不幸的是,如果不编写自己的组件,还没有办法配置它。

一种快速执行此操作的方法是编写一个新的头部,该头部继承自包含您的损失的头部。 然后,损失将被替换为使用您的损失权重初始化的新损失对象。

例如在DeepLabV3+的情况下,这看起来像这样:

import torch
from torch import nn
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from detectron2.projects.deeplab import DeepLabCE, DeepLabV3PlusHead

@SEM_SEG_HEADS_REGISTRY.register()
class MyNewHead(DeepLabV3PlusHead):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        weight = torch.Tensor([0.4, 0.6])  # Adapt to your case
        if self.loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(
                reduction="mean", ignore_index=self.ignore_value, weight=weight
            )
        elif self.loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(
                ignore_label=self.ignore_value,
                top_k_percent_pixels=0.2,
                weight=weight,
            )
        else:
            raise ValueError("Unexpected loss type: %s" % self.loss_type)

然后,您修改配置文件以选择新头:

MODEL:  
  SEM_SEG_HEAD:
    NAME: "MyNewHead"

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM