
How is data augmentation done in each epoch?

I'm new to PyTorch and want to apply data augmentation to the datasets on each epoch. Here is my code:

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize([0, 0, 0], [1, 1, 1])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0, 0, 0], [1, 1, 1])
])

cifar10_train = CIFAR10(root="/data", train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(cifar10_train, batch_size=128, shuffle=True)

cifar10_test = CIFAR10(root="/data", train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(cifar10_test, batch_size=128, shuffle=True)

I got the code from an online tutorial. From what I understand, train_transform and test_transform define the augmentation, while cifar10_train and cifar10_test are where the data is loaded and augmented at the same time. Does this mean data augmentation is done only once, before training? What if I want to do data augmentation for each epoch?

Following is a rough flow of operations:

  1. Read 128 examples from the dataset.
  2. Apply the transformations to each example as it is read, then collate the examples into a batch.
  3. Pass the batch to the network.

Hence, the transformations are re-applied every time a batch is drawn, irrespective of the epoch; the random augmentations your model sees are different in every epoch.
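The key detail is that the transform lives inside the dataset's __getitem__, so it re-runs on every fetch. A minimal pure-Python sketch of that pattern (no torch required; ToyDataset and random_flip are hypothetical stand-ins for a torchvision dataset and RandomHorizontalFlip):

```python
import random

class ToyDataset:
    """Mimics torchvision datasets: the transform is applied lazily
    inside __getitem__, i.e. every time a sample is fetched."""
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, idx):
        sample = self.data[idx]              # raw, un-augmented sample
        if self.transform is not None:
            sample = self.transform(sample)  # fresh random augmentation
        return sample

    def __len__(self):
        return len(self.data)

def random_flip(seq):
    # stand-in for RandomHorizontalFlip: reverse with probability 0.5
    return seq[::-1] if random.random() < 0.5 else seq

random.seed(0)
ds = ToyDataset(["abcd"], transform=random_flip)

# Fetching the same index across "epochs" re-runs the transform,
# so both flipped and unflipped versions show up over time.
samples = [ds[0] for _ in range(8)]
print(samples)
```

Because the stored data is never modified, every epoch draws new random augmentations from the same raw samples; nothing is augmented "once before training".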

I think you have some misunderstandings about the code. cifar10_train and cifar10_test load the raw, un-augmented dataset into Python; each sample then goes through the transforms as it is fetched.

In most cases only the training set is augmented; the test set is left alone because it is supposed to replicate real-world data. That is why train_transform contains RandomHorizontalFlip and RandomCrop (which do the augmentation), and why test_transform does not.

The transforms (train_transform and test_transform) decide how the data is augmented, normalized, and converted into PyTorch tensors; you can think of them as a set of guidelines/rules for the dataset to follow.

The loaders (train_loader and test_loader) split the data into batches (groups of samples). Note that the transforms are applied by the dataset itself each time the loader requests a sample, which is why the random augmentations differ from batch to batch and from epoch to epoch.
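One more detail worth flagging: Normalize subtracts the mean and divides by the standard deviation per channel, so the mean=[0, 0, 0], std=[1, 1, 1] used in the question is an identity operation. A tiny sketch of the arithmetic (the CIFAR-10 channel statistic 0.4914/0.2470 below is a commonly quoted value, used here as an assumption; compute the stats from your own data to be exact):

```python
def normalize(pixel, mean, std):
    # what transforms.Normalize does to each channel value
    return (pixel - mean) / std

# mean=0, std=1 leaves the value unchanged (a no-op)
print(normalize(0.5, 0.0, 1.0))        # 0.5

# with dataset statistics the values get centered and rescaled
print(normalize(0.5, 0.4914, 0.2470))  # ≈ 0.0348
```

Leaving the no-op Normalize in place is harmless, but if normalization is wanted at all, dataset statistics are the usual choice.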
