Hello, I am learning about GANs and deep learning. Until now I have always worked with NumPy arrays of images, but for this homework I get the data with tfds in the following way:
test_split, valid_split, train_split = tfds.Split.TRAIN.subsplit([10, 15, 75])
test_set_raw = tfds.load('cats_vs_dogs', split=test_split, as_supervised=True)
valid_set_raw = tfds.load('cats_vs_dogs', split=valid_split, as_supervised=True)
train_set_raw = tfds.load('cats_vs_dogs', split=train_split, as_supervised=True)
The problem is that I want to apply data augmentation to some of these examples, but I cannot access every image in these _OptionsDataset objects, only a few at a time with take(). I want to iterate over the dataset, apply data augmentation to every image, and add the new images to the dataset. I could do this with NumPy and two arrays, but I have no idea how to do it with an _OptionsDataset.
Is this possible? How can I do it? Is it possible to convert an _OptionsDataset to a NumPy array and then convert the NumPy array back to an _OptionsDataset?
Thanks
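For the NumPy round trip asked about above, a minimal sketch: iterate the dataset with as_numpy_iterator(), stack the elements into arrays, augment in NumPy, then rebuild a dataset with tf.data.Dataset.from_tensor_slices. The random 32x32 images are a synthetic stand-in for cats_vs_dogs (an assumption); the real images differ in size, so they would have to be resized to a common shape before stacking, and the whole dataset must fit in memory.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for a tfds (image, label) dataset; assumption: all images
# already share one shape, which real cats_vs_dogs images do not.
images = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
labels = np.arange(8) % 2
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Dataset -> NumPy: iterate every element and stack them into arrays.
pairs = list(ds.as_numpy_iterator())
x = np.stack([img for img, _ in pairs])
y = np.array([lbl for _, lbl in pairs])

# Augment in NumPy: append a horizontally flipped copy of every image.
x_aug = np.concatenate([x, x[:, :, ::-1, :]], axis=0)
y_aug = np.concatenate([y, y], axis=0)

# NumPy -> Dataset again.
ds_aug = tf.data.Dataset.from_tensor_slices((x_aug, y_aug))
```

This works, but it loads everything into memory at once; the tf.image approach in the answer below avoids that by augmenting inside the pipeline.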
tf.image has a bunch of random transformations you can use; you don't need NumPy. Here's an example. I had to select the splits a little differently since I have a different version of tfds. Here is the documentation for tf.image.
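import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow's C++ log messages

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

# Newer tfds versions select splits with slicing strings instead of subsplit().
[train_set_raw] = tfds.load('cats_vs_dogs', split=['train[:100]'], as_supervised=True)

def augment(tensor):
    tensor = tf.cast(x=tensor, dtype=tf.float32)
    tensor = tf.image.rgb_to_grayscale(images=tensor)        # (H, W, 3) -> (H, W, 1)
    tensor = tf.image.resize(images=tensor, size=(96, 96))   # images differ in size
    tensor = tf.divide(x=tensor, y=tf.constant(255.))        # scale to [0, 1]
    tensor = tf.image.random_flip_left_right(image=tensor)
    tensor = tf.image.random_brightness(image=tensor, max_delta=2e-1)
    tensor = tf.clip_by_value(tensor, 0., 1.)                # brightness can leave [0, 1]
    tensor = tf.image.random_crop(value=tensor, size=(64, 64, 1))
    return tensor

# map() applies augment to every image lazily, each time the dataset is iterated.
train_set_raw = train_set_raw.shuffle(128).map(lambda x, y: (augment(x), y)).batch(16)

# Show the first image of the first batch.
images, labels = next(iter(train_set_raw))
plt.imshow(images[0][..., 0].numpy(), cmap='gray', vmin=0., vmax=1.)
plt.show()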
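Note that map() replaces each image with an augmented version rather than adding new images. To keep the originals and append augmented copies, as the question asks, one option is Dataset.concatenate. A minimal sketch, assuming a synthetic 10-image dataset in place of train_set_raw and hypothetical helper names preprocess and augment_pair:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for train_set_raw (assumption: 10 RGB images of one size).
images = np.random.randint(0, 256, size=(10, 96, 96, 3), dtype=np.uint8)
labels = np.arange(10) % 2
train_set_raw = tf.data.Dataset.from_tensor_slices((images, labels))

def preprocess(image, label):
    # Hypothetical helper: convert to float32 in [0, 1], no augmentation.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

def augment_pair(image, label):
    # Hypothetical helper: random flip and brightness, clipped back to [0, 1].
    image, label = preprocess(image, label)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

originals = train_set_raw.map(preprocess)
augmented = train_set_raw.map(augment_pair)

# concatenate() appends the augmented copies after the originals (both must
# have the same element types and shapes); shuffle() then mixes them.
train_set = originals.concatenate(augmented).shuffle(64).batch(16)
```

The augmented copies are regenerated with new random flips and brightness on every pass over the dataset, so each epoch sees slightly different augmented images.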