PyTorch: creating a dataset with augmented images

Is there a way to build a DataLoader that yields X augmented versions of each image? The code I have currently produces only a single augmented image per source image.

import os

from PIL import Image
from torch.utils import data


class ImageDataset(data.Dataset):

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # file names of all images found in root_dir
        self.img_names = os.listdir(root_dir)
        self.transform = transform

    def __getitem__(self, index):
        # load the image from disk and force three colour channels
        img = Image.open(os.path.join(self.root_dir, self.img_names[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.img_names)

Also, I would like to add labels, so that augmented images derived from the same source image share the same label.

Here are a couple of equivalent ways to do this.

Option 1

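We can change the dataset class itself to serve the same data multiple times. This is done by reporting a longer length from __len__ and selecting the image name with index % len(self.img_names). Because the transform runs on every __getitem__ call, a random transform gives each repetition of an image a different augmentation.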

class ImageDataset(data.Dataset):
    def __init__(self, root_dir, repetitions=1, transform=None):
        self.root_dir = root_dir
        # file names of all images found in root_dir
        self.img_names = os.listdir(root_dir)
        self.transform = transform
        self.repetitions = repetitions

    def __getitem__(self, index):
        # wrap the index so every source image is served `repetitions` times
        img = Image.open(os.path.join(self.root_dir,
            self.img_names[index % len(self.img_names)])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.img_names) * self.repetitions
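
The question also asks for labels, which the class above does not return. A minimal sketch of one way to add them, assuming the label is simply the position of the source image in the sorted file list (an illustrative choice, not something the question specifies). The class name LabeledImageDataset and the loader parameter are hypothetical. It is written as a plain class so it runs without PyTorch installed; for real use, subclass data.Dataset as in the code above.

```python
import os


class LabeledImageDataset:
    """Map-style dataset in which each source image appears `repetitions` times.

    Assumption: the label is the position of the source image in the sorted
    file list, so all augmented copies of one image share the same label.
    """

    def __init__(self, root_dir, repetitions=1, transform=None, loader=None):
        self.root_dir = root_dir
        # sort so that the index-to-label mapping is stable between runs
        self.img_names = sorted(os.listdir(root_dir))
        self.repetitions = repetitions
        self.transform = transform
        # `loader` is a hypothetical hook for swapping the image reader;
        # by default images are opened with Pillow, as in the question
        self.loader = loader if loader is not None else self._load_rgb

    @staticmethod
    def _load_rgb(path):
        from PIL import Image  # imported lazily; requires Pillow
        return Image.open(path).convert('RGB')

    def __getitem__(self, index):
        # without this check the modulo below would make iteration endless
        if not 0 <= index < len(self):
            raise IndexError(index)
        label = index % len(self.img_names)
        img = self.loader(os.path.join(self.root_dir, self.img_names[label]))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.img_names) * self.repetitions
```

Passing an instance of this to a DataLoader with shuffle=True should then give batches of (image, label) pairs, provided the transform ends by converting the image to a tensor so the default collation can stack it.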

Option 2

Alternatively, we could wrap the dataset in a torch.utils.data.Subset and pass the same indices multiple times.

# using your original implementation for ImageDataset
import torch
dataset = ImageDataset(root, transforms)
# repeat the full index list, e.g. [0, 1, 2, 0, 1, 2] for 3 images and repetitions=2
dataset = torch.utils.data.Subset(dataset, list(range(len(dataset))) * repetitions)
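
For the label part of the question, the index list handed to Subset already carries the needed information: sample k of the Subset is looked up as dataset[indices[k]], so indices[k] identifies the source image. A small sketch of building that list in pure Python; the helper name repeated_indices and its grouped flag are hypothetical, added only for illustration.

```python
def repeated_indices(num_items, repetitions, grouped=False):
    """Return dataset indices with every item repeated `repetitions` times.

    grouped=False: whole passes one after another, e.g. [0, 1, 2, 0, 1, 2]
    grouped=True:  copies of one item adjacent,     e.g. [0, 0, 1, 1, 2, 2]
    """
    if grouped:
        return [i for i in range(num_items) for _ in range(repetitions)]
    return list(range(num_items)) * repetitions


indices = repeated_indices(3, 2)
# indices[k] is the source image of the k-th sample, so it can double
# as that sample's label (assuming one label per source image)
labels = indices
print(indices)                               # [0, 1, 2, 0, 1, 2]
print(repeated_indices(3, 2, grouped=True))  # [0, 0, 1, 1, 2, 2]
```

The ordering only matters when the DataLoader does not shuffle. With shuffle=True both layouts should behave the same.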
