
Augmentation of a tf.data.Dataset

While following this guide, I stumbled across this: in order to augment a tf.data Dataset, we manually use the map function to apply image transformations to each image of the original dataset:

import tensorflow as tf

def convert(image, label):
  image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]
  return image, label

def augment(image, label):
  image, label = convert(image, label) # Already casts to float32 in [0,1], so no second conversion is needed
  image = tf.image.resize_with_crop_or_pad(image, 34, 34) # Add 6 pixels of padding
  image = tf.image.random_crop(image, size=[28, 28, 1]) # Random crop back to 28x28
  image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness

  return image, label

BATCH_SIZE = 64
# Only use a subset of the data so it's easier to overfit, for this tutorial
NUM_EXAMPLES = 2048
AUTOTUNE = tf.data.AUTOTUNE # tf.data.experimental.AUTOTUNE on older TF 2.x releases

# train_dataset and num_train_examples are defined earlier in the guide.
augmented_train_batches = (
    train_dataset
    # Only train on a subset, so you can quickly see the effect.
    .take(NUM_EXAMPLES)
    .cache()
    .shuffle(num_train_examples//4)
    # The augmentation is added here.
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

From what I can understand, this takes the original train_dataset and creates a new augmented_train_batches dataset, which has the same number of images, altered by the map transformations. This dataset is then fed into .fit like this:

model_with_aug.fit(augmented_train_batches, epochs=50, validation_data=validation_batches)

So what I can't seem to grasp is this: isn't the data supposed to be altered after every epoch so that (according to the documentation) our model won't see the same image more than once, which in turn lowers our chances of overfitting?

In this tutorial, isn't augmented_train_batches just a slightly altered dataset which is fed over and over to our model?

Or is the augmentation somehow being applied after each epoch in a way I can't understand?

PS: I suppose augmentation (if done correctly) must alter the original, pre-transformed data in the same manner after every epoch, rather than keep applying transformations to the same already-altered dataset.

Is the augmentation somehow being applied after each epoch in a way I can't understand?

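Not as a separate step after each epoch, but the effect is the one you expect. A tf.data pipeline is lazy: the function passed to .map() is re-executed every time the dataset is iterated, and .fit() iterates it once per epoch. Because .map(augment) comes after .cache() in this tutorial, only the unaugmented images are cached, and the random crop and brightness are redrawn for every image in every epoch. If .cache() were placed after .map(augment) instead, the augmented images would be frozen after the first pass. Another common way to generate augmented data per epoch is the TF Keras Image Data Generator, which creates an iterator that feeds the augmented data directly to the model. You can read more about it in this link.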
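Whether the random ops inside .map() are redrawn on each pass can be checked directly. The following is a minimal sketch, not the tutorial's code: the constant gray images and the helper one_pass are hypothetical stand-ins, and only the brightness step of augment is kept. It iterates a small pipeline twice, once with .cache() before .map() (as in the tutorial) and once with .cache() after it.

```python
import numpy as np
import tensorflow as tf

def augment(image):
    # Random brightness shift, drawn each time the map function runs.
    return tf.image.random_brightness(image, max_delta=0.5)

# Hypothetical stand-in data: 8 constant gray 28x28x1 float32 images.
images = tf.fill([8, 28, 28, 1], 0.5)
source = tf.data.Dataset.from_tensor_slices(images)

# Same ordering as the tutorial: cache first, then augment.
redrawn = source.cache().map(augment).batch(8)

# Reversed ordering: the augmented images themselves get cached.
frozen = source.map(augment).cache().batch(8)

def one_pass(dataset):
    # Iterate the whole dataset once, as a single epoch of .fit() would.
    return np.concatenate([batch.numpy() for batch in dataset])

epoch_1, epoch_2 = one_pass(redrawn), one_pass(redrawn)
frozen_1, frozen_2 = one_pass(frozen), one_pass(frozen)

print("cache -> map, passes identical:", np.allclose(epoch_1, epoch_2))
print("map -> cache, passes identical:", np.allclose(frozen_1, frozen_2))
```

With .cache() before .map(), the two passes are expected to differ, because each pass draws new brightness shifts; with .cache() after .map(), the second pass reads the stored augmented images back unchanged.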

The tutorial just introduces you to the basic concept and the benefits of Data Augmentation.

Also take note of this part of the tutorial:

BATCH_SIZE = 64
# Only use a subset of the data so it's easier to overfit, for this tutorial
NUM_EXAMPLES = 2048

The tutorial deliberately uses only a subset of the data, which makes the model much more prone to overfitting. This might be the reason you are worried that the chances of overfitting are higher.

Augmentation is a way to get more data: we only need to make minor alterations to our existing dataset, such as flips, translations, or rotations. You can do these with tf.image and apply them to each item in the dataset using the .map() method. Our neural network would treat these as distinct images anyway.
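As a rough illustration of such minor alterations, here is a small sketch that applies a random horizontal flip and a random 90-degree rotation through .map(). The function name flip_and_rotate and the random stand-in images are assumptions made for this example; the tf.image and tf.data calls are standard TensorFlow.

```python
import numpy as np
import tensorflow as tf

def flip_and_rotate(image, label):
    # Mirror the image horizontally with probability 0.5.
    image = tf.image.random_flip_left_right(image)
    # Rotate by a random multiple of 90 degrees (0, 90, 180, or 270).
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)
    return image, label

# Hypothetical stand-in data: 4 random 28x28x1 images with integer labels.
images = np.random.rand(4, 28, 28, 1).astype("float32")
labels = np.arange(4)

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
augmented = dataset.map(flip_and_rotate, num_parallel_calls=tf.data.AUTOTUNE)

for image, label in augmented:
    print(label.numpy(), image.shape)
```

Flips and 90-degree rotations only rearrange pixels, so the labels stay valid as long as the classes are not orientation-sensitive (for digits, a rotated 6 can look like a 9, so such transformations should be chosen with the data in mind).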

In the tutorial, one model is trained on the non-augmented data and another on the augmented data simply to compare the two and show how small the difference is.

In this example the augmented model converges to an accuracy ~95% on the validation set. This is slightly higher (+1%) than the model trained without data augmentation.

We can clearly see that there isn't a huge difference between the two. Normally, though, augmentation is used to add more altered data to your dataset, so the difference may be bigger if you combine the augmented data with your original dataset and increase the number of epochs.
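One possible way to combine augmented data with the original dataset is Dataset.concatenate. The sketch below is only an illustration under assumed names (brighten, the random stand-in images, and the buffer and batch sizes are all hypothetical), not something taken from the tutorial.

```python
import numpy as np
import tensorflow as tf

def brighten(image, label):
    # Random brightness shift; the label is passed through unchanged.
    image = tf.image.random_brightness(image, max_delta=0.5)
    return image, label

# Hypothetical stand-in data: 16 random 28x28x1 images with labels 0-9.
images = np.random.rand(16, 28, 28, 1).astype("float32")
labels = np.arange(16) % 10
original = tf.data.Dataset.from_tensor_slices((images, labels))

combined = (
    original
    .concatenate(original.map(brighten))  # originals followed by augmented copies
    .shuffle(32)                          # mix the two halves together
    .batch(8)
)

# Count the examples seen in one pass over the combined dataset.
num_examples = sum(int(batch_images.shape[0]) for batch_images, _ in combined)
print(num_examples)
```

Each pass over combined yields every original image once plus one freshly augmented copy of it, so the model sees twice as many examples per epoch.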


 