繁体   English   中英

TensorFlow Python - 有没有办法将 tensorflow_datasets 数据集插入 ImageGenerator?

[英]TensorFlow Python - Is there a way to insert a tensorflow_datasets dataset into ImageGenerator?

我一直在尝试在 tensorflow.keras.preprocessing.image.ImageGenerator function 中使用数据增强,但我看到的每个示例都在包含文件的目录中通过。 我的目标是使用 tensorflow_datasets 导入 MNIST,然后将其传递给数据增强 function,但我无法找到方法。

如果目录更容易,并且如果有人能找到一种简单的方法来做到这一点,我愿意使用目录,并成功地向我解释如何做到这一点。

请参阅下面的代码

谢谢你,马克斯

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def main():
    data, info = tfds.load("mnist", with_info=True)
    train_data, test_data = data['train'], data['test']

    image_gen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode='nearest')

    #
    # What Do I Do Here??
    #

    # train_data_gen = image_gen.flow(data)


if __name__ == "__main__":
    main()


今天早上我花了很大一部分时间试图弄清楚在 TF2 中处理数据加载扩充的“正确”方法是什么——如果这些方法中的任何一种是正确的。 “TensorFlow 2.0 删除了多余的 API”就这么多,对吧? 我目前的理解是这两种数据加载方法是独立的,并且可以单独使用(尽管如果来自 TF 的人可以在这里插话,那就太好了)。

首先, ImageDataGenerator通过实时数据增强生成批量张量图像数据。 我见过一些使用tfds将数据集读入 numpy 数组和ImageDataGenerator的解决方案,但如果这不是反模式,我会感到震惊。 我的理解是,如果您使用ImageDataGenerator ,您应该使用它来加载和预处理您的数据。

我选择了 go 与官方(也许?!) tensorflow_datasets路线。 我没有利用ImageDataGenerator中的内置增强功能,而是使用tfds.load加载数据集,然后结合缓存和map调用来进行预处理。

首先,我们使用 S3 API 加载数据:

train_size = 45000
val_size = 5000
(train, val, test), info = tfds.load(
    "cifar10:3.*.*",
    as_supervised=True,
    split=[
        "train[0:{}]".format(train_size),
        "train[{}:{}]".format(train_size, train_size + val_size),
        "test",
    ],
    with_info=True
)

然后,使用一系列辅助函数和tf.image函数,我们可以进行预处理:

def _pad_image(
    image: tf.Tensor, label: tf.Tensor, padding: int = 4
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Pads and image and returns a given supervised training pair."""
    image = tf.pad(
        image, [[padding, padding], [padding, padding], [0, 0]], mode="CONSTANT"
    )
    return image, label


def _crop_image(
    image: tf.Tensor, label: tf.Tensor, out_size: List[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly crops an image and returns a given supervised training pair."""
    if out_size is None:
        out_size = [32, 32, 3]
    image = tf.image.random_crop(image, out_size)
    return image, label


def _random_flip(
    image: tf.Tensor, label: tf.Tensor, prob=0.5
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly flips an image and returns a given supervised training pair."""
    if tf.random.uniform(()) > 1 - prob:
        image = tf.image.flip_left_right(image)
    return image, label


processed_train = (
    train.map(_pad_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .map(_crop_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(_random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)

由于我们有一个tf.data.Dataset ,我们可以使用标准的 API 进行批处理、重复等。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM