
PyTorch - Train imbalanced dataset (set weights) for object detection

I am quite new to PyTorch, and I am trying to use transfer learning with an object detection model so that it learns to detect the objects in my new dataset.

Here is how I load the dataset:

train_dataset = MyDataset(train_data_path, 512, 512, train_labels_path, get_train_transform())
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_dataset = MyDataset(test_data_path, 512, 512, test_labels_path, get_valid_transform())
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn)
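The `collate_fn` passed to both loaders is not shown above. A minimal sketch of what such a function commonly looks like for detection data, assuming each dataset item is an `(image, target)` pair (the stand-in strings and dicts below are placeholders, not real tensors):

```python
def collate_fn(batch):
    """Group a list of (image, target) pairs into (images, targets) tuples.

    Detection images can hold different numbers of boxes, so the targets
    cannot be stacked into one tensor; keeping them as tuples avoids that.
    """
    return tuple(zip(*batch))


# Stand-in samples: strings replace image tensors, dicts replace targets.
batch = [("img0", {"labels": [1]}), ("img1", {"labels": [2, 2]})]
images, targets = collate_fn(batch)
```

With this shape, `images, targets = data` in a training loop unpacks one tuple of images and one tuple of target dicts per batch.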

I define the model and optimizer as follows:

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load Faster RCNN pre-trained model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="FasterRCNN_ResNet50_FPN_Weights.COCO_V1")
# get the number of input features
in_features = model.roi_heads.box_predictor.cls_score.in_features
# define a new head for the detector with the required number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model = model.to(DEVICE)
# get the model parameters
params = [p for p in model.parameters() if p.requires_grad]
# define the optimizer
# We are using the SGD optimizer with a learning rate of 0.001 and momentum of 0.9.
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)

I train the model as follows:

def train(train_data_loader, model, optimizer, train_loss_hist):

    global train_itr
    global train_loss_list

    prog_bar = tqdm(train_data_loader, total=len(train_data_loader), position=0, leave=True, ascii=True)

    # Then we have the for loop iterating over the batches.

    for i, data in enumerate(prog_bar):
        optimizer.zero_grad()
        images, targets = data

        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # Forward pass
        loss_dict = model(images, targets)

        # Then we sum the losses and append the current iteration's loss value to train_loss_list.
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        # We also send the current loss value to train_loss_hist of the Averager class.
        train_loss_list.append(loss_value)
        train_loss_hist.send(loss_value)

        # Then we backpropagate the gradients and update parameters.
        losses.backward()
        optimizer.step()
        train_itr += 1
    return train_loss_list
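For context, in training mode torchvision's Faster R-CNN returns a dict of named loss components (`loss_classifier`, `loss_box_reg`, `loss_objectness`, `loss_rpn_box_reg`), and the loop above simply adds them up. A minimal sketch of scaling those components before summing is shown below. The helper name `weighted_total` and the weight values are hypothetical, and plain floats stand in for the tensors the model would return:

```python
def weighted_total(loss_dict, weights=None):
    """Sum the loss components, scaling each by an optional weight.

    Components missing from `weights` keep a weight of 1.0, so passing
    no weights reproduces the plain sum used in the training loop.
    """
    weights = weights or {}
    return sum(weights.get(name, 1.0) * value for name, value in loss_dict.items())


# Stand-in values; in training these would be the tensors returned by the model.
loss_dict = {
    "loss_classifier": 0.5,
    "loss_box_reg": 0.25,
    "loss_objectness": 0.125,
    "loss_rpn_box_reg": 0.125,
}
# Hypothetical weights: emphasise the classification term.
total = weighted_total(loss_dict, {"loss_classifier": 2.0})
```

Note that this only rescales whole loss components. It does not weight individual classes, which would require changing the classification loss inside the model itself.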

I adapted this code from an example I found, and I am not sure where the loss is defined. I have not defined any loss in the code, so I assume it uses the default loss that the original object detector was trained with. Given that my dataset is imbalanced, how can I train the network, and how should I update my code?

It seems that you have two questions.

  1. How to deal with an imbalanced dataset. Note that Faster R-CNN is an anchor-based detector, which means the number of anchors containing an object is extremely small compared to the total number of anchors. The detector is therefore already built to train under heavy foreground/background imbalance, so you may not need to do anything special for the imbalanced dataset. Alternatively, you can use RetinaNet, which introduced a loss function called focal loss to improve performance on imbalanced data.
  2. Where the loss function is. torchvision integrates the loss function inside the model object. You can step through your Python code with a debugger into the torchvision package to see the implementation details.
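To illustrate the focal loss mentioned in point 1, here is a minimal scalar sketch of the binary form, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). It is written in plain Python for readability and is not the torchvision implementation. The defaults alpha=0.25 and gamma=2 are the commonly quoted values, and the function names are chosen for this example only:

```python
import math


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Binary focal loss for one prediction.

    p: predicted probability of the positive class, 0 < p < 1
    y: ground-truth label, 1 (positive) or 0 (negative)
    """
    p_t = p if y == 1 else 1.0 - p              # probability given to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-balancing factor
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def cross_entropy(p, y):
    """Plain binary cross-entropy, for comparison."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)


easy = focal_loss(0.95, 1)  # confident, correct prediction
hard = focal_loss(0.10, 1)  # confident, wrong prediction
```

The `(1 - p_t) ** gamma` factor shrinks the loss of well-classified examples, so the many easy examples contribute far less than the few hard ones. For tensors, torchvision provides `torchvision.ops.sigmoid_focal_loss`, and `torchvision.models.detection.retinanet_resnet50_fpn` builds a RetinaNet model.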

