
How can I do data augmentation with _OptionsDataset?

Hello, I am learning about GANs and deep learning. When I worked with images before, I used NumPy arrays, but for this homework I get the data from tfds in the following way:

import tensorflow_datasets as tfds

# Legacy TFDS split API; newer TFDS versions select splits with strings such as 'train[:10%]'
test_split, valid_split, train_split = tfds.Split.TRAIN.subsplit([10, 15, 75])
test_set_raw = tfds.load('cats_vs_dogs', split=test_split, as_supervised=True)
valid_set_raw = tfds.load('cats_vs_dogs', split=valid_split, as_supervised=True)
train_set_raw = tfds.load('cats_vs_dogs', split=train_split, as_supervised=True)

The problem is that I want to apply data augmentation to some of these examples, but the only way I can get at the images in this _OptionsDataset is with take(). I want to iterate over it, augment every image, and add the new images to the dataset.

I could do this with NumPy and two arrays, but I have no idea how to do it with an _OptionsDataset.

Is this possible, and how can I do it? Is it possible to convert an _OptionsDataset to a NumPy array and then convert the NumPy array back to an _OptionsDataset?

Thanks
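On the NumPy part of the question: a round trip is possible, although the tf.data approach in the answer avoids it. The following is a minimal sketch that assumes small, fixed-size images that fit in memory. A synthetic random dataset stands in for the TFDS split (the names `images`, `labels`, and `ds` are hypothetical), `Dataset.as_numpy_iterator()` (TensorFlow 2.1 or later) pulls the elements out as NumPy arrays, and `tf.data.Dataset.from_tensor_slices` turns the augmented arrays back into a dataset. With the real cats_vs_dogs split you could also use `tfds.as_numpy(...)`, but its images have different sizes, so you would need to resize them before `np.stack` can combine them. The result is an ordinary `tf.data.Dataset` rather than an `_OptionsDataset` specifically, but it supports the same `map`/`shuffle`/`batch` API.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a TFDS split: 8 random 64x64 RGB "images" with
# binary labels. Real cats_vs_dogs images vary in size and would need resizing
# before they could be stacked into a single array.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 64, 64, 3), dtype=np.uint8)
labels = rng.integers(0, 2, size=(8,))
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Dataset -> NumPy: as_numpy_iterator() yields (image, label) pairs as NumPy values
pairs = list(ds.as_numpy_iterator())
x = np.stack([img for img, _ in pairs])
y = np.array([lbl for _, lbl in pairs])

# Augment in NumPy (here a horizontal flip along the width axis) and append
# the flipped copies after the originals, as with "two arrays"
x_aug = np.concatenate([x, x[:, :, ::-1, :]], axis=0)
y_aug = np.concatenate([y, y], axis=0)

# NumPy -> Dataset again
ds_aug = tf.data.Dataset.from_tensor_slices((x_aug, y_aug))
```

Any NumPy-based augmentation could replace the flip here. The drawback of this route is that the whole split has to be materialized in memory, which is why staying inside tf.data is usually preferable for image datasets.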

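tf.image provides a number of random transformations (flips, brightness changes, crops, and so on) that you can apply directly to the dataset with map(), so you don't need NumPy. Here's an example. I had to select the splits a little differently because I have a different version of TFDS. See the documentation for tf.image for the full list of transformations.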

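import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # hide TensorFlow's C++ log messages
import tensorflow as tf
import tensorflow_datasets as tfds

# Split-slicing syntax: the first 100 examples of the train split
[train_set_raw] = tfds.load('cats_vs_dogs', split=['train[:100]'], as_supervised=True)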


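def augment(tensor):
    # Preprocess: convert to float, grayscale, resize to a fixed 96x96, scale to [0, 1]
    tensor = tf.cast(x=tensor, dtype=tf.float32)
    tensor = tf.image.rgb_to_grayscale(images=tensor)
    tensor = tf.image.resize(images=tensor, size=(96, 96))
    tensor = tf.divide(x=tensor, y=tf.constant(255.))
    # Augment: these random ops are re-sampled each time an element is read
    tensor = tf.image.random_flip_left_right(image=tensor)
    tensor = tf.image.random_brightness(image=tensor, max_delta=2e-1)
    tensor = tf.image.random_crop(value=tensor, size=(64, 64, 1))
    # random_brightness can push pixels outside [0, 1], so clip them back
    tensor = tf.clip_by_value(tensor, clip_value_min=0., clip_value_max=1.)
    return tensor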


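# Shuffle, apply augment() to every image (labels pass through unchanged), then batch
train_set_raw = train_set_raw.shuffle(128).map(lambda x, y: (augment(x), y)).batch(16)

import matplotlib.pyplot as plt

# Display the single channel of the first image in one batch, in grayscale
plt.imshow((next(iter(train_set_raw))[0][0][..., 0].numpy()*255).astype(int), cmap='gray')
plt.show()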
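The pipeline above replaces each image with a freshly randomized version every time the dataset is read; it does not add new images. If you want to keep the originals and append augmented copies, as the question asks (the tf.data analogue of concatenating two NumPy arrays), `Dataset.concatenate` can do that. This is a minimal sketch that uses synthetic fixed-size images in place of `train_set_raw`. The helper names `base`, `original_view`, and `augmented_view` are hypothetical, and the originals get a deterministic centre crop so both halves have the same 64x64x1 shape, which `concatenate` requires.

```python
import numpy as np
import tensorflow as tf


def base(image):
    # Shared preprocessing: float, grayscale, fixed 96x96 size, scale to [0, 1]
    image = tf.cast(image, tf.float32)
    image = tf.image.rgb_to_grayscale(image)
    image = tf.image.resize(image, (96, 96))
    return image / 255.0


def original_view(image):
    # Deterministic centre crop so originals match the augmented 64x64x1 shape
    return tf.image.resize_with_crop_or_pad(base(image), 64, 64)


def augmented_view(image):
    # Same random transformations as the answer's augment(), plus clipping
    image = tf.image.random_flip_left_right(base(image))
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_crop(image, size=(64, 64, 1))
    return tf.clip_by_value(image, 0.0, 1.0)


# Hypothetical stand-in for train_set_raw: 10 random 120x100 RGB images with labels
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(10, 120, 100, 3), dtype=np.uint8)
labels = rng.integers(0, 2, size=(10,))
raw = tf.data.Dataset.from_tensor_slices((images, labels))

originals = raw.map(lambda x, y: (original_view(x), y))
augmented = raw.map(lambda x, y: (augmented_view(x), y))

# Keep all 10 originals and append one augmented copy of each (20 examples),
# then shuffle with a buffer large enough to mix both halves
combined = originals.concatenate(augmented).shuffle(32).batch(4)
```

Because `map` runs lazily, the augmented half is re-randomized on every epoch, so over several epochs the model sees more than one variant of each image. To append more than one augmented copy per image, you could build the augmented half from `raw.repeat(k)` instead of `raw`.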

