繁体   English   中英

PyTorch - 为 object 检测训练不平衡数据集(设置权重)

[英]PyTorch - Train imbalanced dataset (set weights) for object detection

我对 PyTorch 很陌生,我正在尝试使用 object 检测 model 进行迁移学习,以了解如何检测我的新数据集。

这是我加载数据集的方式:

train_dataset = MyDataset(train_data_path, 512, 512, train_labels_path, get_train_transform())
train_loader = DataLoader(train_dataset,batch_size=8,shuffle=True,num_workers=4,collate_fn=collate_fn)
valid_dataset = MyDataset(test_data_path, 512, 512, test_labels_path, get_valid_transform())
valid_loader = DataLoader(valid_dataset,batch_size=8, shuffle=False,num_workers=4,collate_fn=collate_fn)

我定义 model 和优化器如下:

# load Faster RCNN pre-trained model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="FasterRCNN_ResNet50_FPN_Weights.COCO_V1") # get the number of input features
in_features = model.roi_heads.box_predictor.cls_score.in_features
# define a new head for the detector with the required number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model = model.to(DEVICE)
# get the model parameters
params = [p for p in model.parameters() if p.requires_grad]
# define the optimizer
# We are using the SGD optimizer with a learning rate of 0.001 and momentum on 0.9.
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)

我训练 model 如下:

def train(train_data_loader, model, optimizer, train_loss_hist):

    global train_itr
    global train_loss_list

    prog_bar = tqdm(train_data_loader, total=len(train_data_loader), position=0, leave=True, ascii=True)

    # Then we have the for loop iterating over the batches.

    for i, data in enumerate(prog_bar):
        optimizer.zero_grad()
        images, targets = data

        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # Forward pass
        loss_dict = model(images, targets)

        # Then we sum the losses and append the current iterations loss value to train_loss_list list.
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        # We also send the current loss value to train_loss_hist of the Averager class.
        train_loss_list.append(loss_value)
        train_loss_hist.send(loss_value)

        # Then we backpropagate the gradients and update parameters.
        losses.backward()
        optimizer.step()
        train_itr += 1
    return train_loss_list

考虑到我修改了我找到的一个代码,但我不确定在哪里定义了损失(我没有在代码中定义任何类型的损失,所以我相信它将使用用于训练原始 object 检测器的默认损失) ,考虑到这样一个不平衡的数据集,我该如何训练我的网络并更新我的代码?

看来你有两个问题。

  1. 如何处理不平衡的数据集。 请注意,Faster-RCNN 是一个基于锚的检测器,这意味着包含 object 的锚的数量与总锚的数量相比非常少,因此您不需要处理不平衡的数据集。 或者您可以使用 RetinaNet,它提出了一种称为焦点损失的损失 function 来提高不平衡数据集的性能。
  2. 损失 function 在哪里。 torchvision integrated the loss function inside the model object, you can debug your python code step by step inside the torchvision package and see the implementation details

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM