簡體   English   中英

如何在 Detectron2 中設置自定義類權重

[英]How set custom class weights in Detectron2

我將https://detectron2.readthedocs.io/tutorials/install.html用於我的數據集以及其他類和對象。

而且我的數據集是不平衡的。 我希望為每個班級設置不同的權重。 我怎樣才能做到這一點?

不幸的是,如果不編寫自己的組件,還沒有辦法配置它。

一種快速執行此操作的方法是編寫一個新的頭部,該頭部繼承自包含您的損失的頭部。 然后,損失將被替換為使用您的損失權重初始化的新損失對象。

例如在DeepLabV3+的情況下,這看起來像這樣:

import torch
from torch import nn
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from detectron2.projects.deeplab import DeepLabCE, DeepLabV3PlusHead

@SEM_SEG_HEADS_REGISTRY.register()
class MyNewHead(DeepLabV3PlusHead):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        weight = torch.Tensor([0.4, 0.6])  # Adapt to your case
        if self.loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(
                reduction="mean", ignore_index=self.ignore_value, weight=weight
            )
        elif self.loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(
                ignore_label=self.ignore_value,
                top_k_percent_pixels=0.2,
                weight=weight,
            )
        else:
            raise ValueError("Unexpected loss type: %s" % self.loss_type)

然后,您修改配置文件以選擇新頭:

MODEL:  
  SEM_SEG_HEAD:
    NAME: "MyNewHead"

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM