Tensorflow CNN image augmentation pipeline

I'm trying to learn the new TensorFlow APIs, and I am a bit lost on where to get a handle on my input batch tensors so I can manipulate and augment them with, for example, tf.image.

This is my current network and pipeline:

trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays

#...
train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))

#...
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

#...defining cnn architecture...

# In the train loop (num_epochs is a placeholder; "..." marks elided arguments)
for epoch in range(num_epochs):
    sess.run(train_init_op)  # switching to train data
    sess.run(train_step, ...)  # running a train step

    # ...
    sess.run(test_init_op)  # switching to test data
    test_loss = sess.run(loss, ...)  # computing the test loss after each epoch

I'm using the Dataset API to create two datasets so that in the train loop I can calculate and log both the train and the test loss.

Where in this pipeline would I manipulate and distort my input batch of images? I'm not creating any tf.placeholder for my trainX input batches, so I don't see where I could apply tf.image; for example, tf.image.flip_up_down requires a 3-D or 4-D tensor.

  • What is the natural way to implement this pipeline with the new API?
  • Is there a module or easy way to augment an input batch of images for training that would fit in this pipeline?

There's a really good article and talk released recently that go over the API in a lot more detail than my response here. Here's a brief example:

import tensorflow as tf
import numpy as np


def read_data():
    n_train = 100
    n_test = 50
    height = 20
    width = 30
    channels = 3
    trainX = (np.random.random(
        size=(n_train, height, width, channels)) * 255).astype(np.uint8)
    testX = (np.random.random(
            size=(n_test, height, width, channels))*255).astype(np.uint8)
    trainY = (np.random.random(size=(n_train,))*10).astype(np.int32)
    testY = (np.random.random(size=(n_test,))*10).astype(np.int32)
    return trainX, testX, trainY, testY


trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays


train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))


def map_single(x, y):
    print('Map single:')
    print('x shape: %s' % str(x.shape))
    print('y shape: %s' % str(y.shape))
    x = tf.image.per_image_standardization(x)
    # Consider: x = tf.image.random_flip_left_right(x)
    return x, y


def map_batch(x, y):
    print('Map batch:')
    print('x shape: %s' % str(x.shape))
    print('y shape: %s' % str(y.shape))
    # Note: this flips ALL images left to right. Not sure this is what you want
    # UPDATE: looks like tf documentation is wrong and you need a 3D tensor?
    # return tf.image.flip_left_right(x), y
    return x, y


batch_size = 32
train_dataset = train_dataset.repeat().shuffle(100)
train_dataset = train_dataset.map(map_single, num_parallel_calls=8)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.map(map_batch)
train_dataset = train_dataset.prefetch(2)

test_dataset = test_dataset.map(
        map_single, num_parallel_calls=8).batch(batch_size).map(map_batch)
test_dataset = test_dataset.prefetch(2)


iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)


with tf.Session() as sess:
    sess.run(train_init_op)
    feat, lab = sess.run((features, labels))

    print(feat.shape)
    print(lab.shape)

    sess.run(test_init_op)
    feat, lab = sess.run((features, labels))

    print(feat.shape)
    print(lab.shape)    
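The `# Consider:` comment in `map_single` hints at random, per-image augmentation. Below is a minimal sketch of that idea, assuming TensorFlow 2.x with eager execution (the code above uses the TF 1.x Session API). The function names `augment_single`, `augment_batch` and `make_train_dataset` are hypothetical, and the particular augmentations (random flip, random brightness) are only examples. The element-wise map runs before `.batch()` and sees 3-D images; the batch-wise map runs after it and sees a 4-D tensor.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x, eager execution (the default).


def augment_single(x, y):
    # Runs before .batch(): x is one 3-D image [height, width, channels].
    x = tf.image.convert_image_dtype(x, tf.float32)  # uint8 -> float32 in [0, 1]
    x = tf.image.random_flip_left_right(x)           # flips about half the images
    x = tf.image.random_brightness(x, max_delta=0.1)  # random brightness shift
    x = tf.clip_by_value(x, 0.0, 1.0)                # keep pixels in [0, 1]
    return x, y


def augment_batch(x, y):
    # Runs after .batch(): x is a 4-D tensor [batch, height, width, channels].
    # This is deterministic and flips EVERY image in the batch; it is shown
    # only to illustrate a batch-wise (4-D) operation.
    x = tf.image.flip_up_down(x)
    return x, y


def make_train_dataset(images, labels, batch_size=32):
    # repeat -> shuffle -> map -> batch -> batch-wise map -> prefetch
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.repeat().shuffle(100)
    ds = ds.map(augment_single, num_parallel_calls=4)
    ds = ds.batch(batch_size)
    ds = ds.map(augment_batch)
    return ds.prefetch(2)


images = (np.random.random((100, 20, 30, 3)) * 255).astype(np.uint8)
labels = (np.random.random((100,)) * 10).astype(np.int32)
train_ds = make_train_dataset(images, labels)

for feat, lab in train_ds.take(1):
    print(feat.shape, feat.dtype, lab.shape)
```

Because the random ops are inside `map`, each epoch sees a different random variant of every image, with no placeholders involved.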

A few notes:

  1. This approach relies on being able to load your entire dataset into memory. If you cannot, consider using tf.data.Dataset.from_generator, although this can lead to slow shuffle times if your shuffle buffer is large. My preferred method is to load some keys tensor entirely into memory (it might just be the indices of each example), then map each key to its data values using tf.py_func. This is slightly less efficient than converting to tfrecords, but with prefetching it likely won't affect performance. Since the shuffling is done before the mapping, you only have to load shuffle_buffer keys into memory, rather than shuffle_buffer examples.
  2. To augment your dataset, use tf.data.Dataset.map either before the batch operation, for an element-wise operation (on a 3D image tensor), or after it, for a batch-wise operation (on a 4D image tensor). Note that the documentation for tf.image.flip_left_right looks out of date, since I get an error when I try to use it on a 4D tensor. If you want to augment your data randomly, use tf.image.random_flip_left_right rather than tf.image.flip_left_right.
  3. If you're using a tf.estimator.Estimator (or wouldn't mind converting your code to use it), then check out tf.estimator.train_and_evaluate for a built-in way of switching between datasets.
  4. Consider shuffling and repeating your dataset with the shuffle and repeat methods. See the article for notes on efficiency. In particular, repeat -> shuffle -> map -> batch -> batch-wise map -> prefetch seems to be the best ordering of operations for most applications.
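The keys-then-load approach from note 1 can be sketched as below. This is a minimal illustration, not the answer author's code: it assumes TensorFlow 2.x and uses tf.numpy_function (the TF 2.x counterpart of tf.py_func, which passes NumPy values to the Python function). `load_example` is a hypothetical loader that synthesizes an image from its index so the sketch runs standalone; in practice it would read example number `index` from disk. Only the integer keys are shuffled, so the shuffle buffer holds indices rather than images.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x. HEIGHT/WIDTH/CHANNELS and load_example are
# hypothetical stand-ins for your real data layout and loader.
HEIGHT, WIDTH, CHANNELS = 20, 30, 3


def load_example(index):
    # Hypothetical loader: a real one would read image `index` from disk.
    # Here the image is generated deterministically from the index.
    rng = np.random.RandomState(int(index))
    image = (rng.random_sample((HEIGHT, WIDTH, CHANNELS)) * 255).astype(np.uint8)
    label = np.int32(int(index) % 10)
    return image, label


def tf_load_example(index):
    # Wrap the Python loader so it can run inside Dataset.map.
    image, label = tf.numpy_function(load_example, [index], [tf.uint8, tf.int32])
    # numpy_function loses static shape information, so restore it.
    image.set_shape((HEIGHT, WIDTH, CHANNELS))
    label.set_shape(())
    return image, label


n_examples = 1000
keys = tf.data.Dataset.range(n_examples)  # only the int64 indices are in memory
ds = (keys.repeat()
      .shuffle(buffer_size=500)            # shuffles keys, not images
      .map(tf_load_example, num_parallel_calls=4)
      .batch(32)
      .prefetch(2))

for images, labels in ds.take(1):
    print(images.shape, labels.shape)
```

Augmentation maps like the ones discussed in note 2 can be chained after `tf_load_example` in the same way.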

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this content, please indicate the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 