TensorFlow Python - Is there a way to insert a tensorflow_datasets dataset into ImageDataGenerator?
I have been trying for a long time to use data augmentation with tensorflow.keras.preprocessing.image.ImageDataGenerator, but every example I have seen points it at a directory of image files. My goal is to use tensorflow_datasets to import MNIST and then pass that to the data augmentation function, but I haven't been able to find out how.
I am willing to use a directory if that is easier, provided someone can find a simple way to do it and explain clearly how.
See code below.
Thank you, Max
```python
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def main():
    data, info = tfds.load("mnist", with_info=True)
    train_data, test_data = data['train'], data['test']

    image_gen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode='nearest')

    #
    # What Do I Do Here??
    #
    # train_data_gen = image_gen.flow(data)


if __name__ == "__main__":
    main()
```
I've spent a large part of this morning trying to figure out which of these two methods, if either, is the "correct" way to handle data loading AND augmentation in TF2. So much for "TensorFlow 2.0 removes redundant APIs", right? My current understanding is that these two data loading methods are independent and meant to be used separately (though it would be great if someone from the TF team could chime in here).
First, ImageDataGenerator generates batches of tensor image data with real-time data augmentation. I've seen some solutions that use tfds to read a dataset into a NumPy array and then feed that array to an ImageDataGenerator, but I'd be shocked if that wasn't an anti-pattern. My understanding is that if you use ImageDataGenerator, you should be using it to both load and preprocess your data.
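For reference, the NumPy-bridge approach described above (which is also the most direct answer to the question's "What Do I Do Here??") can be sketched as follows. This is a minimal sketch, assuming the images are already in memory as an `(N, H, W, C)` array: `ImageDataGenerator.flow` accepts NumPy arrays, not a `tf.data.Dataset`. To keep it runnable offline, random `uint8` arrays shaped like MNIST stand in for the real data; with tfds you would obtain arrays with something like `tfds.as_numpy(tfds.load("mnist", split="train", batch_size=-1, as_supervised=True))`. The helper name `make_augmented_batches` is hypothetical, and the geometric transforms require SciPy to be installed.

```python
import numpy as np
import tensorflow as tf


def make_augmented_batches(images, labels, batch_size=32, seed=0):
    """Wrap in-memory (N, H, W, C) arrays in an augmenting ImageDataGenerator iterator."""
    image_gen = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode="nearest",
    )
    # flow() takes NumPy arrays rather than a tf.data.Dataset; it yields
    # (augmented_images, labels) batches and shuffles by default.
    return image_gen.flow(images, labels, batch_size=batch_size, seed=seed)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical stand-in for MNIST. With tfds you might instead use:
    # images, labels = tfds.as_numpy(
    #     tfds.load("mnist", split="train", batch_size=-1, as_supervised=True))
    images = rng.integers(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
    labels = rng.integers(0, 10, size=(100,))
    x_batch, y_batch = next(make_augmented_batches(images, labels))
    print(x_batch.shape, y_batch.shape)
```

This loads the whole split into memory, which is workable for MNIST-sized data but is one reason the approach can feel like a workaround rather than a native pipeline.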
I've opted to go with the official (maybe?!) tensorflow_datasets route. Rather than leverage the built-in augmentations in ImageDataGenerator, I use tfds.load to load a dataset and then a combination of caching and map calls to do my preprocessing.
First, we load our data using the TFDS split-slicing ("S3") API:
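```python
import tensorflow_datasets as tfds

train_size = 45000
val_size = 5000

# Carve a validation set out of CIFAR-10's 50,000-image "train" split.
(train, val, test), info = tfds.load(
    "cifar10:3.*.*",
    as_supervised=True,
    split=[
        "train[0:{}]".format(train_size),
        "train[{}:{}]".format(train_size, train_size + val_size),
        "test",
    ],
    with_info=True,
)
```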
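The two `train[...]` slice strings select contiguous index ranges of the training split. As a rough mental model only (TFDS applies the slices when reading files, and this is not how it is implemented), the same carve-out can be pictured with `tf.data`'s `take` and `skip` on a stand-in dataset of example indices. `carve_splits` is a hypothetical helper.

```python
import tensorflow as tf


def carve_splits(full_train: tf.data.Dataset, train_size: int, val_size: int):
    """Mimic "train[0:train_size]" and "train[train_size:train_size+val_size]"."""
    train = full_train.take(train_size)  # first train_size examples
    val = full_train.skip(train_size).take(val_size)  # the next val_size examples
    return train, val


if __name__ == "__main__":
    # Hypothetical stand-in for the 50,000-example "train" split: just indices.
    full = tf.data.Dataset.range(50000)
    train, val = carve_splits(full, train_size=45000, val_size=5000)
    print(train.cardinality().numpy(), val.cardinality().numpy())
```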
Then, using a series of helpers and tf.image functions, we can do the preprocessing:
```python
from typing import List, Optional, Tuple

import tensorflow as tf


def _pad_image(
    image: tf.Tensor, label: tf.Tensor, padding: int = 4
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Pads an image and returns a given supervised training pair."""
    image = tf.pad(
        image, [[padding, padding], [padding, padding], [0, 0]], mode="CONSTANT"
    )
    return image, label


def _crop_image(
    image: tf.Tensor, label: tf.Tensor, out_size: Optional[List[int]] = None
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly crops an image and returns a given supervised training pair."""
    if out_size is None:
        out_size = [32, 32, 3]
    image = tf.image.random_crop(image, out_size)
    return image, label


def _random_flip(
    image: tf.Tensor, label: tf.Tensor, prob=0.5
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Randomly flips an image and returns a given supervised training pair."""
    # Inside Dataset.map, AutoGraph turns this Python `if` on a tensor
    # into a graph conditional evaluated per element.
    if tf.random.uniform(()) > 1 - prob:
        image = tf.image.flip_left_right(image)
    return image, label


# Deterministic padding is cached; the random crop and flip run after
# cache() so each epoch gets fresh augmentations.
processed_train = (
    train.map(_pad_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .map(_crop_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(_random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)
```
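To sanity-check the shapes this pipeline produces without downloading CIFAR-10, here is a compact variant run on random `uint8` arrays shaped like CIFAR-10 images. It is a sketch under that assumption, and it swaps the hand-written `_random_flip` for the built-in `tf.image.random_flip_left_right` (which flips with probability 0.5). `build_pipeline`, `_pad`, and `_crop_and_flip` are hypothetical names.

```python
import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # older TF 2.x releases: tf.data.experimental.AUTOTUNE


def _pad(image: tf.Tensor, label: tf.Tensor, padding: int = 4):
    """Zero-pad height and width by `padding` pixels on each side."""
    return tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]), label


def _crop_and_flip(image: tf.Tensor, label: tf.Tensor):
    """Random 32x32 crop, then a left-right flip with probability 0.5."""
    image = tf.image.random_crop(image, [32, 32, 3])
    return tf.image.random_flip_left_right(image), label


def build_pipeline(images: np.ndarray, labels: np.ndarray) -> tf.data.Dataset:
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    # Same ordering as the answer: deterministic padding is cached,
    # random augmentation runs after cache() so it differs every epoch.
    return (
        ds.map(_pad, num_parallel_calls=AUTOTUNE)
        .cache()
        .map(_crop_and_flip, num_parallel_calls=AUTOTUNE)
    )


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical stand-in for CIFAR-10: 16 random 32x32 RGB images.
    images = rng.integers(0, 256, size=(16, 32, 32, 3), dtype=np.uint8)
    labels = rng.integers(0, 10, size=(16,))
    ds = build_pipeline(images, labels)
    print(ds.element_spec)
```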
Since we now have a tf.data.Dataset, we can do batching, repeating, etc. using the standard API.
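One plausible way to finish the pipeline is sketched below on a synthetic dataset. The `finalize` helper, the batch size, and the shuffle buffer size are assumptions for illustration, not part of the original answer.

```python
import numpy as np
import tensorflow as tf


def finalize(
    ds: tf.data.Dataset, batch_size: int = 128, shuffle_buffer: int = 1000
) -> tf.data.Dataset:
    """Shuffle, batch, and prefetch an (image, label) dataset for training."""
    return (
        ds.shuffle(shuffle_buffer)  # reshuffled on each pass by default
        .batch(batch_size)  # the last batch may be smaller than batch_size
        .prefetch(tf.data.AUTOTUNE)  # overlap preprocessing with training
    )


if __name__ == "__main__":
    # Hypothetical stand-in for `processed_train`: 300 CIFAR-shaped examples.
    images = np.zeros((300, 32, 32, 3), dtype=np.uint8)
    labels = np.zeros((300,), dtype=np.int64)
    train_ds = finalize(tf.data.Dataset.from_tensor_slices((images, labels)))
    for batch_images, batch_labels in train_ds:
        print(batch_images.shape, batch_labels.shape)
```

With Keras `Model.fit`, a finite dataset is typically re-iterated every epoch, so `repeat()` is mainly useful when you pass `steps_per_epoch` and want an endless stream.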