[英]PyTorch - Train imbalanced dataset (set weights) for object detection
我对 PyTorch 很陌生,我正在尝试使用 object 检测 model 进行迁移学习,以了解如何检测我的新数据集。
这是我加载数据集的方式:
train_dataset = MyDataset(train_data_path, 512, 512, train_labels_path, get_train_transform())
train_loader = DataLoader(train_dataset,batch_size=8,shuffle=True,num_workers=4,collate_fn=collate_fn)
valid_dataset = MyDataset(test_data_path, 512, 512, test_labels_path, get_valid_transform())
valid_loader = DataLoader(valid_dataset,batch_size=8, shuffle=False,num_workers=4,collate_fn=collate_fn)
我定义 model 和优化器如下:
# load Faster RCNN pre-trained model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="FasterRCNN_ResNet50_FPN_Weights.COCO_V1") # get the number of input features
in_features = model.roi_heads.box_predictor.cls_score.in_features
# define a new head for the detector with the required number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model = model.to(DEVICE)
# get the model parameters
params = [p for p in model.parameters() if p.requires_grad]
# define the optimizer
# We are using the SGD optimizer with a learning rate of 0.001 and momentum on 0.9.
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
我训练 model 如下:
def train(train_data_loader, model, optimizer, train_loss_hist):
global train_itr
global train_loss_list
prog_bar = tqdm(train_data_loader, total=len(train_data_loader), position=0, leave=True, ascii=True)
# Then we have the for loop iterating over the batches.
for i, data in enumerate(prog_bar):
optimizer.zero_grad()
images, targets = data
images = list(image.to(DEVICE) for image in images)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
# Forward pass
loss_dict = model(images, targets)
# Then we sum the losses and append the current iterations loss value to train_loss_list list.
losses = sum(loss for loss in loss_dict.values())
loss_value = losses.item()
# We also send the current loss value to train_loss_hist of the Averager class.
train_loss_list.append(loss_value)
train_loss_hist.send(loss_value)
# Then we backpropagate the gradients and update parameters.
losses.backward()
optimizer.step()
train_itr += 1
return train_loss_list
考虑到我修改了我找到的一个代码,但我不确定在哪里定义了损失(我没有在代码中定义任何类型的损失,所以我相信它将使用用于训练原始 object 检测器的默认损失) ,考虑到这样一个不平衡的数据集,我该如何训练我的网络并更新我的代码?
看来你有两个问题。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.