What happens after data augmentation is done?

I am using Kaggle's "Dogs vs. Cats" dataset and following TensorFlow's CIFAR-10 tutorial (for convenience, I did not use weight decay, moving averages, or L2 loss). I trained my network successfully, but when I added data augmentation to my code, something strange happened: the loss never went down, even after thousands of steps (before I added it, everything worked fine). The code is shown below:

def get_batch(image, label, image_w, image_h, batch_size, capacity, test_flag=False):
  '''
  Args:
      image: list of image file paths
      label: list of integer labels
      image_w: image width
      image_h: image height
      batch_size: batch size
      capacity: the maximum number of elements in the queue
      test_flag: if True, build a test batch (no augmentation); otherwise a training batch
  Returns:
      image_batch: 4D tensor [batch_size, height, width, 3], dtype=tf.float32
      label_batch: 1D tensor [batch_size], dtype=tf.int32
  '''

  image = tf.cast(image, tf.string)
  label = tf.cast(label, tf.int32)

  # make an input queue
  input_queue = tf.train.slice_input_producer([image, label])

  label = input_queue[1]
  image_contents = tf.read_file(input_queue[0])
  image = tf.image.decode_jpeg(image_contents, channels=3)

  ####################################################################
  # Data augmentation goes here.
  # For the test set, leave the images as they are.

  if not test_flag:
    image = tf.image.resize_image_with_crop_or_pad(image, RESIZED_IMG, RESIZED_IMG)
    # Randomly crop a [height, width] section of the image.
    distorted_image = tf.random_crop(image, [image_h, image_w, 3])

    # Randomly flip the image horizontally.
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # Because these operations are not commutative, consider randomizing
    # the order of their operation.
    # NOTE: since per_image_standardization zeros the mean and makes
    # the stddev unit, this likely has no effect; see tensorflow#1458.
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)

    image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
  else:
    # Arguments are (image, target_height, target_width).
    image = tf.image.resize_image_with_crop_or_pad(image, image_h, image_w)

  ######################################################################

  # Subtract the mean and divide by the standard deviation of the pixels.
  image = tf.image.per_image_standardization(image)
  # Set the shapes of tensors.
  image.set_shape([image_h, image_w, 3])
  # label.set_shape([1])

  image_batch, label_batch = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            num_threads=64,
                                            capacity=capacity)

  label_batch = tf.reshape(label_batch, [batch_size])
  image_batch = tf.cast(image_batch, tf.float32)

  return image_batch, label_batch

Make sure the limits you use (e.g. max_delta=63 for brightness, upper=1.8 for contrast) are low enough that an image is still recognizable. Another possible problem is that the augmentation is applied over and over again, so that after a few iterations the image is completely distorted (though I didn't spot this bug in your snippet).
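To make both points concrete, here is a minimal plain-Python sketch (no TensorFlow). The helper names and the synthetic one-channel "image" are made up for illustration, and the arithmetic only imitates what a brightness or contrast adjustment does; it is not TensorFlow's implementation. It shows that a delta of 63 is a moderate shift when pixels are on a 0-255 scale but saturates every pixel when they are on a 0-1 scale, and that re-applying a contrast factor of 1.8 to the same image compounds quickly. As far as I know, the tf.image adjustment ops rescale integer images to [0, 1] internally, so it is worth checking the dtype of the tensor you pass in against the docs.

```python
import random


def adjust_brightness(pixels, delta):
    """Add delta to every pixel (plain-Python stand-in, not TensorFlow code)."""
    return [p + delta for p in pixels]


def adjust_contrast(pixels, factor):
    """Scale each pixel's distance from the image mean by factor."""
    mean = sum(pixels) / len(pixels)
    return [(p - mean) * factor + mean for p in pixels]


def clip(pixels, lo, hi):
    """Clamp every pixel into the valid range [lo, hi]."""
    return [min(max(p, lo), hi) for p in pixels]


def saturated_fraction(pixels, lo, hi):
    """Fraction of pixels stuck at either end of the valid range."""
    return sum(1 for p in pixels if p <= lo or p >= hi) / len(pixels)


rng = random.Random(0)
# Hypothetical mid-range image on a 0-255 scale, and the same image on 0-1.
image_255 = [rng.uniform(64, 180) for _ in range(1000)]
image_unit = [p / 255.0 for p in image_255]

# 1) Same delta, different pixel scale.
shifted_255 = clip(adjust_brightness(image_255, 63), 0.0, 255.0)
shifted_unit = clip(adjust_brightness(image_unit, 63), 0.0, 1.0)
frac_255 = saturated_fraction(shifted_255, 0.0, 255.0)
frac_unit = saturated_fraction(shifted_unit, 0.0, 1.0)

# 2) One contrast pass versus five passes on the same image.
once = clip(adjust_contrast(image_255, 1.8), 0.0, 255.0)
repeated = image_255
for _ in range(5):
    repeated = clip(adjust_contrast(repeated, 1.8), 0.0, 255.0)
single_frac = saturated_fraction(once, 0.0, 255.0)
repeated_frac = saturated_fraction(repeated, 0.0, 255.0)

print("saturated after brightness, 0-255 scale:", frac_255)
print("saturated after brightness, 0-1 scale:  ", frac_unit)
print("saturated after 1 contrast pass:        ", single_frac)
print("saturated after 5 contrast passes:      ", repeated_frac)
```

If most pixels end up saturated, per-image standardization afterwards has almost nothing left to work with, which would be consistent with a loss that never moves.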

What I suggest is adding a visualization of your data to TensorBoard. To visualize an image, use the tf.summary.image method. You'll be able to see the result of the augmentation clearly.

tf.summary.image('input', image_batch, 10)

This gist can serve as an example.
