TensorFlow Python - Is there a way to insert a tensorflow_datasets dataset into ImageGenerator?

I have been trying for a long time to use data augmentation with the tensorflow.keras.preprocessing.image.ImageDataGenerator class, but every example I have seen passes in a directory of files. My goal is to use tensorflow_datasets to import MNIST and then pass that dataset to the data augmentation function, but I haven't been able to figure out how.

I am willing to use a directory instead, if that is easier and someone can explain an easy way to set it up.

See code below

Thank you, Max

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def main():
    data, info = tfds.load("mnist", with_info=True)
    train_data, test_data = data['train'], data['test']

    image_gen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode='nearest')

    #
    # What Do I Do Here??
    #

    # train_data_gen = image_gen.flow(data)


if __name__ == "__main__":
    main()


I've spent a large part of this morning trying to figure out which of these two methods, if either, is the "correct" way to handle data loading and augmentation in TF2. So much for "TensorFlow 2.0 removes redundant APIs", right? My current understanding is that these two data loading methods are independent and meant to be used separately (though it would be great if someone from TF could chime in here).

First, ImageDataGenerator generates batches of tensor image data with real-time data augmentation. I've seen some solutions that use tfds to read a dataset into a numpy array and then feed it to an ImageDataGenerator, but I'd be shocked if that wasn't an anti-pattern. My understanding is that if you use ImageDataGenerator, you should be using it to both load and preprocess your data.
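
For completeness, here is a minimal sketch of that numpy-array approach, assuming the MNIST dataset from the question. It pulls the whole split into memory via batch_size=-1, which is exactly why I'd treat it as a last resort:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# batch_size=-1 returns the entire split as a single batch of tensors;
# tfds.as_numpy converts that (images, labels) pair to numpy arrays.
images, labels = tfds.as_numpy(
    tfds.load("mnist", split="train", as_supervised=True, batch_size=-1)
)

image_gen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2)

# flow() yields augmented batches from the in-memory arrays.
train_data_gen = image_gen.flow(images, labels, batch_size=32)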

I've opted to go with the official (maybe?!) tensorflow_datasets route. Rather than leverage the built-in augmentations in ImageDataGenerator, I use tfds.load to load a dataset and then a combination of caching and map calls to do my preprocessing.

First we load our data using the S3 API:

from typing import List, Tuple

import tensorflow as tf
import tensorflow_datasets as tfds

train_size = 45000
val_size = 5000
(train, val, test), info = tfds.load(
    "cifar10:3.*.*",
    as_supervised=True,
    split=[
        "train[0:{}]".format(train_size),
        "train[{}:{}]".format(train_size, train_size + val_size),
        "test",
    ],
    with_info=True
)

Then, using a series of helpers and tf.image functions, we can do preprocessing:

def _pad_image(
    image: tf.Tensor, label: tf.Tensor, padding: int = 4
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Pads and image and returns a given supervised training pair."""
    image = tf.pad(
        image, [[padding, padding], [padding, padding], [0, 0]], mode="CONSTANT"
    )
    return image, label


def _crop_image(
    image: tf.Tensor, label: tf.Tensor, out_size: List[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly crops an image and returns a given supervised training pair."""
    if out_size is None:
        out_size = [32, 32, 3]
    image = tf.image.random_crop(image, out_size)
    return image, label


def _random_flip(
    image: tf.Tensor, label: tf.Tensor, prob=0.5
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly flips an image and returns a given supervised training pair."""
    if tf.random.uniform(()) > 1 - prob:
        image = tf.image.flip_left_right(image)
    return image, label
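
As an aside (not part of the original pipeline): if a fixed 50% flip probability is fine, tf.image also has a built-in op that does the same thing without the explicit branch:

def _random_flip_builtin(
    image: tf.Tensor, label: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Same idea as _random_flip, using the built-in op (fixed 50% probability)."""
    return tf.image.random_flip_left_right(image), label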


processed_train = (
    train.map(_pad_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .map(_crop_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(_random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)

Since we have a tf.data.Dataset , we can then do batching, repeat, etc. using the standard API.
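
For example, a minimal sketch (the normalization step, shuffle buffer, batch size, and toy model below are illustrative choices, not part of the original pipeline):

# Normalize, shuffle, batch, and prefetch the augmented dataset.
train_ds = (
    processed_train
    .map(lambda image, label: (tf.cast(image, tf.float32) / 255.0, label),
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(10000)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Any Keras model taking (32, 32, 3) float inputs works here; this tiny one
# just shows that model.fit consumes the tf.data.Dataset directly.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])
model.fit(train_ds, epochs=5)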
