
PyTorch for object detection - image augmentation

I am using PyTorch for object detection, fine-tuning an existing model (transfer learning) as described in the following link - https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

While different transformations are used for image augmentation (horizontal flip in this tutorial), the tutorial doesn't mention anything about transforming the bounding boxes/annotations to keep them consistent with the transformed image. Am I missing something basic here?

In the training phase, the transforms are indeed applied to both the images and the targets while the data is loaded. In the PennFudanDataset class, we have these two lines:

if self.transforms is not None:  
    img, target = self.transforms(img, target)

where target is a dictionary containing:

target = {}
target["boxes"] = boxes
target["labels"] = labels
target["masks"] = masks
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

self.transforms in the PennFudanDataset class is set to the return value of get_transform(): a T.Compose wrapping [T.ToTensor(), T.RandomHorizontalFlip(0.5)] when train=True. It is passed in while instantiating the dataset with:

dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
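
For reference, the get_transform() helper from the tutorial looks roughly like this (the tutorial has you copy the detection-specific transforms.py from torchvision's references/detection folder and import it as T):

import transforms as T  # the detection transforms module copied per the tutorial, not torchvision.transforms

def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())  # converts the PIL image to a tensor; the target passes through
    if train:
        # flips the image AND its boxes/masks/keypoints together, half the time
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)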

These transforms come from T, a custom transforms module written for the object detection task (references/detection/transforms.py in the torchvision repository). Specifically, in the __call__ of RandomHorizontalFlip(), we process both the image and the target (e.g., boxes, masks, keypoints):

For the sake of completeness, I borrow the code from the GitHub repo:

import random

class RandomHorizontalFlip(object):
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # flip the image tensor along its last (width) axis
            bbox = target["boxes"]
            # mirror x-coordinates: new x_min = width - old x_max, new x_max = width - old x_min
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)  # flip each instance mask the same way
            if "keypoints" in target:
                keypoints = target["keypoints"]
                # helper defined in the same file: mirrors x and swaps left/right keypoint labels
                keypoints = _flip_coco_person_keypoints(keypoints, width)
                target["keypoints"] = keypoints
        return image, target
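
A quick sanity check of the box arithmetic, with made-up numbers: on a 100-pixel-wide image, a box spanning x in [10, 30] should span x in [70, 90] after the flip, with y untouched:

import torch

width = 100
bbox = torch.tensor([[10., 20., 30., 40.]])  # one box: x_min, y_min, x_max, y_max
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
print(bbox)  # tensor([[70., 20., 90., 40.]]) -- x range mirrored about the image centre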

From this snippet, we can see how the flipping is performed on the boxes, masks, and keypoints in accordance with the image.
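
Note that T.Compose is likewise not torchvision.transforms.Compose: it threads the target through every transform alongside the image. A sketch of what it looks like in the same transforms.py file (assuming the version current at the time of the tutorial):

class Compose(object):
    """Chain transforms that take and return (image, target) pairs."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)  # every transform sees both image and target
        return image, target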
