Is there a way to build a DataLoader that yields X augmented versions of each image? My current code only produces a single augmented image per source image:
class ImageDataset(data.Dataset):
    """Dataset of every regular file found directly inside ``root_dir``.

    Each item is the file at that position, opened with PIL, converted to
    RGB and passed through ``transform`` when one is given. ``transform``
    is applied on every access, so a random augmentation pipeline produces
    a new variant each time the same index is read.

    Args:
        root_dir: Directory to scan for image files (subdirectories are
            skipped).
        transform: Optional callable applied to the loaded PIL image.
        return_label: When True, ``__getitem__`` returns ``(img, label)``
            where ``label`` is the file's position in ``img_names``. The
            default (False) returns only the image, as before.
    """

    def __init__(self, root_dir, transform=None, return_label=False):
        self.root_dir = root_dir
        # Scan root_dir (a bare os.listdir() would list the process CWD)
        # and sort so the index -> file (and hence label) mapping is
        # reproducible; os.listdir order is arbitrary.
        self.img_names = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )
        self.transform = transform
        self.return_label = return_label

    def __getitem__(self, index):
        """Return the (optionally transformed) image at *index*.

        Raises IndexError for out-of-range indices, like a list.
        """
        name = self.img_names[index]
        path = os.path.join(self.root_dir, name)
        # Close the file handle once convert() has produced an RGB copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.return_label:
            # Normalise negative indices so labels are always 0..N-1.
            return img, index % len(self.img_names)
        return img

    def __len__(self):
        """Return the number of files found in ``root_dir``."""
        return len(self.img_names)
Also, I would like to add labels so that all augmented images derived from the same source image share the same label.
Here are a couple of equivalent ways to do this.
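We can change the dataset class itself to provide the same data multiple times. To do this, report a length `repetitions` times longer and select the image name with `index % len(self.img_names)`. The transform runs on every access, so each repetition of an image is a freshly augmented variant. To give every augmentation of the same source image the same label, return `index % len(self.img_names)` alongside the image.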
class ImageDataset(data.Dataset):
    """Image dataset that serves every file in ``root_dir`` ``repetitions`` times.

    Position ``i`` maps to image ``i % len(img_names)``, so the dataset
    cycles through all images once per repetition. ``transform`` runs on
    every access, so with a random augmentation pipeline each repetition
    of an image is a different augmented variant.

    Args:
        root_dir: Directory to scan for image files (subdirectories are
            skipped).
        repetitions: How many times each image appears; must be >= 0.
        transform: Optional callable applied to the loaded PIL image.
        return_label: When True, ``__getitem__`` returns ``(img, label)``,
            where ``label`` is the source image's position in ``img_names``;
            all repetitions of one image share that label. The default
            (False) returns only the image, as before.

    Raises:
        ValueError: If ``repetitions`` is negative.
    """

    def __init__(self, root_dir, repetitions=1, transform=None,
                 return_label=False):
        if repetitions < 0:
            raise ValueError(f"repetitions must be >= 0, got {repetitions}")
        self.root_dir = root_dir
        # Scan root_dir (a bare os.listdir() would list the process CWD)
        # and sort so the index -> file (and hence label) mapping is
        # reproducible; os.listdir order is arbitrary.
        self.img_names = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )
        self.transform = transform
        self.repetitions = repetitions
        self.return_label = return_label

    def _source_index(self, index):
        """Map a dataset position to the index of its source image.

        Negative positions count from the end, as with a list. Raises
        IndexError when *index* is outside ``[-len(self), len(self))``.
        """
        # range() indexing normalises negatives and raises IndexError when
        # out of range. Without that bound the modulo would wrap every
        # index, and plain ``for img in dataset`` iteration (which stops
        # on IndexError) would never terminate. It also avoids modulo by
        # zero on an empty directory.
        position = range(len(self))[index]
        return position % len(self.img_names)

    def __getitem__(self, index):
        """Return the image for *index*, plus its label if ``return_label``."""
        label = self._source_index(index)
        path = os.path.join(self.root_dir, self.img_names[label])
        # Close the file handle once convert() has produced an RGB copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.return_label:
            return img, label
        return img

    def __len__(self):
        """Return the number of images multiplied by ``repetitions``."""
        return len(self.img_names) * self.repetitions
Alternatively, we can keep the original dataset class, wrap it in a `torch.utils.data.Subset`, and list the same indices multiple times.
# Reuse the original, non-repeating ImageDataset and let Subset repeat it.
# Subset position i reads dataset[indices[i]]; the index list below is
# 0..N-1 concatenated `repetitions` times.
dataset = ImageDataset(root, transforms)
indices = [i for _ in range(repetitions) for i in range(len(dataset))]
dataset = torch.utils.data.Subset(dataset, indices)
The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.