[英]How set custom class weights in Detectron2
我将https://detectron2.readthedocs.io/tutorials/install.html用于我的数据集以及其他类和对象。
而且我的数据集是不平衡的。 我希望为每个班级设置不同的权重。 我怎样才能做到这一点?
不幸的是,如果不编写自己的组件,还没有办法配置它。
一种快速执行此操作的方法是编写一个新的头部,该头部继承自包含您的损失的头部。 然后,损失将被替换为使用您的损失权重初始化的新损失对象。
例如在DeepLabV3+的情况下,这看起来像这样:
import torch
from torch import nn
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from detectron2.projects.deeplab import DeepLabCE, DeepLabV3PlusHead
@SEM_SEG_HEADS_REGISTRY.register()
class MyNewHead(DeepLabV3PlusHead):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
weight = torch.Tensor([0.4, 0.6]) # Adapt to your case
if self.loss_type == "cross_entropy":
self.loss = nn.CrossEntropyLoss(
reduction="mean", ignore_index=self.ignore_value, weight=weight
)
elif self.loss_type == "hard_pixel_mining":
self.loss = DeepLabCE(
ignore_label=self.ignore_value,
top_k_percent_pixels=0.2,
weight=weight,
)
else:
raise ValueError("Unexpected loss type: %s" % self.loss_type)
然后,您修改配置文件以选择新头:
MODEL:
SEM_SEG_HEAD:
NAME: "MyNewHead"
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.