
TensorFlow Python - Is there a way to insert a tensorflow_datasets dataset into ImageGenerator?

I have been trying for a long time to use data augmentation with the tensorflow.keras.preprocessing.image.ImageDataGenerator class, but every example I have seen passes in a directory containing the files. My goal is to use tensorflow_datasets to import MNIST and then pass it to the data augmentation step, but I haven't been able to find out how.

I am willing to use a directory if that is easier, provided someone can find an easy way to do that and explain it to me.

See code below.

Thank you, Max

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def main():
    data, info = tfds.load("mnist", with_info=True)
    train_data, test_data = data['train'], data['test']

    image_gen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode='nearest')

    #
    # What Do I Do Here??
    #

    # train_data_gen = image_gen.flow(data)


if __name__ == "__main__":
    main()


I've spent a large part of this morning trying to figure out which, if either, of these methods is the "correct" way to handle data loading AND augmentation in TF2. So much for "TensorFlow 2.0 removes redundant APIs", right? My current understanding is that these two data loading methods are independent and meant to be used separately (though it would be great if someone from TF could chime in here).

First, ImageDataGenerator generates batches of tensor image data with real-time data augmentation. I've seen some solutions that use tfds to read a dataset into a numpy array and then feed that array to an ImageDataGenerator, but I'd be shocked if that wasn't an anti-pattern. My understanding is that if you use ImageDataGenerator, you should be using it to both load and preprocess your data. For reference, a sketch of that numpy round-trip is below.
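Here is a minimal sketch of that round-trip, assuming MNIST as in the question. It materializes the entire split in memory (tfds.load with batch_size=-1 returns the full split as a single batch), which is fine for MNIST but won't scale to larger datasets:

import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load all of MNIST into memory as one big batch, then convert to numpy.
images, labels = tfds.as_numpy(
    tfds.load("mnist", split="train", as_supervised=True, batch_size=-1)
)
print(images.shape, labels.shape)  # (60000, 28, 28, 1) (60000,)

image_gen = ImageDataGenerator(rescale=1.0 / 255, rotation_range=20)
# flow() accepts in-memory arrays instead of a directory.
train_data_gen = image_gen.flow(images, labels, batch_size=32)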

I've opted to go with the official (maybe?!) tensorflow_datasets route. Rather than leverage the built-in augmentations in ImageDataGenerator, I use tfds.load to load a dataset and then a combination of caching and map calls to do my preprocessing.

First we load our data using the S3 API:

train_size = 45000
val_size = 5000
(train, val, test), info = tfds.load(
    "cifar10:3.*.*",
    as_supervised=True,
    split=[
        "train[0:{}]".format(train_size),
        "train[{}:{}]".format(train_size, train_size + val_size),
        "test",
    ],
    with_info=True
)
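As a quick sanity check (a hedged aside, not part of the original pipeline), the returned info object and the datasets themselves can confirm the split arithmetic:

print(info.features)  # shapes/dtypes of the image and label features
print(tf.data.experimental.cardinality(train).numpy())  # 45000, matching train_size
print(tf.data.experimental.cardinality(val).numpy())    # 5000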

Then, using a series of helpers and tf.image functions, we can do preprocessing:

from typing import List, Tuple


def _pad_image(
    image: tf.Tensor, label: tf.Tensor, padding: int = 4
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Pads an image and returns a given supervised training pair."""
    image = tf.pad(
        image, [[padding, padding], [padding, padding], [0, 0]], mode="CONSTANT"
    )
    return image, label


def _crop_image(
    image: tf.Tensor, label: tf.Tensor, out_size: List[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly crops an image and returns a given supervised training pair."""
    if out_size is None:
        out_size = [32, 32, 3]
    image = tf.image.random_crop(image, out_size)
    return image, label


def _random_flip(
    image: tf.Tensor, label: tf.Tensor, prob=0.5
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly flips an image and returns a given supervised training pair."""
    if tf.random.uniform(()) > 1 - prob:
        image = tf.image.flip_left_right(image)
    return image, label


processed_train = (
    train.map(_pad_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .map(_crop_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(_random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)

Since we have a tf.data.Dataset, we can then do batching, repeat, etc. using the standard API. (Note the ordering above: the deterministic pad sits before cache(), so it is computed only once, while the random crop and flip come after the cache and are therefore re-drawn every epoch.)
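A minimal sketch of those downstream steps; the shuffle buffer and batch size here are arbitrary choices, not from the original:

batch_size = 128  # arbitrary; tune for your hardware

train_pipeline = (
    processed_train.shuffle(10_000)  # shuffle buffer size is a guess
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# model.fit(train_pipeline, validation_data=val.batch(batch_size), epochs=10)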
