[英]How set custom class weights in Detectron2
我將https://detectron2.readthedocs.io/tutorials/install.html用於我的數據集以及其他類和對象。
而且我的數據集是不平衡的。 我希望為每個班級設置不同的權重。 我怎樣才能做到這一點?
不幸的是,如果不編寫自己的組件,還沒有辦法配置它。
一種快速執行此操作的方法是編寫一個新的頭部,該頭部繼承自包含您的損失的頭部。 然后,損失將被替換為使用您的損失權重初始化的新損失對象。
例如在DeepLabV3+的情況下,這看起來像這樣:
import torch
from torch import nn
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from detectron2.projects.deeplab import DeepLabCE, DeepLabV3PlusHead
@SEM_SEG_HEADS_REGISTRY.register()
class MyNewHead(DeepLabV3PlusHead):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
weight = torch.Tensor([0.4, 0.6]) # Adapt to your case
if self.loss_type == "cross_entropy":
self.loss = nn.CrossEntropyLoss(
reduction="mean", ignore_index=self.ignore_value, weight=weight
)
elif self.loss_type == "hard_pixel_mining":
self.loss = DeepLabCE(
ignore_label=self.ignore_value,
top_k_percent_pixels=0.2,
weight=weight,
)
else:
raise ValueError("Unexpected loss type: %s" % self.loss_type)
然后,您修改配置文件以選擇新頭:
MODEL:
SEM_SEG_HEAD:
NAME: "MyNewHead"
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.