
How do I correctly apply data augmentation to a TFRecord Dataset?

I am attempting to apply data augmentation to a TFRecord dataset after it has been parsed. However, when I check the size of the dataset before and after mapping the augmentation function, the sizes are the same. I know the parse function is working and the datasets are correct as I have already used them to train a model. So I have only included code to map the function and count the examples afterward.

Here is the code I am using:

import tensorflow as tf

def flip_example(image, label):
    # flip() is my own helper; I have tested it separately
    flipped_image = flip(image)
    return flipped_image, label


dataset = tf.data.TFRecordDataset('train.TFRecord').map(parse_function)

# Size of dataset before augmentation
num_ex = 0
for x in dataset:
    num_ex += 1

dataset = dataset.map(flip_example)

# Size of dataset after augmentation
num_ex = 0
for x in dataset:
    num_ex += 1

In both cases, num_ex = 324 instead of the expected 324 for non-augmented and 648 for augmented. I have also successfully tested the flip function so it seems the issue is with how the function interacts with the dataset. How do I correctly implement this augmentation?

When you apply data augmentation with the tf.data API, it is done on the fly: map transforms each example as it passes through the pipeline and returns exactly one output for every input. Augmenting data this way therefore does not change the number of examples in your pipeline.
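To make the one-in, one-out behaviour of map concrete, here is a minimal sketch. It assumes a toy in-memory dataset of four tiny images in place of the parsed TFRecord file, and uses tf.image.flip_left_right as a stand-in for the asker's own flip helper:

```python
import tensorflow as tf

# Toy stand-in for the parsed TFRecord dataset: 4 images of shape (2, 2, 1).
images = tf.reshape(tf.range(4 * 2 * 2 * 1, dtype=tf.float32), (4, 2, 2, 1))
labels = tf.constant([0, 1, 0, 1])
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def flip_example(image, label):
    # One element in, one element out: map never adds examples.
    return tf.image.flip_left_right(image), label

flipped = dataset.map(flip_example)

count_before = sum(1 for _ in dataset)
count_after = sum(1 for _ in flipped)
print(count_before, count_after)  # both counts are 4
```

The second dataset replaces each image with its flipped version, which is why the count in the question stays at 324.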

If you want to use every example n times, add dataset = dataset.repeat(count=n). You should also switch to tf.image.random_flip_left_right; otherwise the flip is applied the same way every time and the repeated examples are identical copies.
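A rough sketch of the repeat approach, again on an assumed toy dataset rather than the real TFRecord file. The helper name random_flip_example and the value n = 3 are illustrative only:

```python
import tensorflow as tf

n = 3  # illustrative: how many times each example should be seen

# Toy stand-in for the parsed TFRecord dataset: 4 images of shape (2, 2, 1).
images = tf.reshape(tf.range(4 * 2 * 2 * 1, dtype=tf.float32), (4, 2, 2, 1))
labels = tf.constant([0, 1, 0, 1])
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def random_flip_example(image, label):
    # Each call flips the image with probability 0.5.
    return tf.image.random_flip_left_right(image), label

# Repeat the source n times, then augment each repeated element independently.
augmented = dataset.repeat(count=n).map(random_flip_example)

num_ex = sum(1 for _ in augmented)
print(num_ex)  # 4 examples * 3 repeats = 12
```

With the 324-example dataset from the question and n = 2, the same pipeline would yield 648 elements, each of them randomly flipped or left unchanged.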

In your example, the second time you check num_ex, dataset contains only the flipped images, so the count is still 324. Furthermore, if you have a larger dataset than these 324 examples, you might want to look into online data augmentation. In this case the dataset is augmented differently in every epoch during training, and you train only on the augmented data, not on the original dataset. This helps the trained model generalise better. (https://www.tensorflow.org/tutorials/images/data_augmentation)
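If the goal really is a dataset holding both the originals and their flipped copies (the 648 the question expected), one option not mentioned in the answer above is to concatenate the two datasets. This is only a sketch on an assumed toy dataset, with tf.image.flip_left_right standing in for the asker's flip helper:

```python
import tensorflow as tf

# Toy stand-in for the parsed TFRecord dataset: 4 images of shape (2, 2, 1).
images = tf.reshape(tf.range(4 * 2 * 2 * 1, dtype=tf.float32), (4, 2, 2, 1))
labels = tf.constant([0, 1, 0, 1])
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def flip_example(image, label):
    # Deterministic flip, so every original gets exactly one mirrored copy.
    return tf.image.flip_left_right(image), label

# Originals followed by their flipped versions.
combined = dataset.concatenate(dataset.map(flip_example))

num_combined = sum(1 for _ in combined)
print(num_combined)  # 4 originals + 4 flipped = 8
```

Because the flipped copies all come after the originals, a shuffle step would normally follow before batching.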
