Keras ImageDataGenerator with center crop for rotation and translation shift

I need to do data augmentation, but without any of the fill modes (constant, reflect, nearest, wrap). Instead, every time an image is rotated or translated, I would like it to be center-cropped (shown below) so that it has no black, white, reflected, or constant edges/borders, as explained here.

[image]

How do I extend the ImageDataGenerator class (if that's the only way to do it and no center crop is available out of the box) with these points taken into account?

  1. Keep existing parts of the ImageDataGenerator other than the augmentation part, and write a custom augmentation function

  2. It would be more efficient to keep the images at their original size until augmentation has happened, because a center crop applied after resizing would lose a lot of data. Translate/Rotate -> Center crop -> Resize should be more efficient than Resize -> Translate/Rotate -> Center crop

In case someone is looking for a solution, here's how I managed to solve the problem. The main idea is to wrap the ImageDataGenerator in a custom generator, like this:

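def crop_generator(batches, new_size):
    # Center-crop every batch coming out of a Keras iterator.
    # Assumes channels-last batches and an even new_size
    # (an odd new_size yields crops of new_size - 1 pixels).
    while True:
        batch_x, batch_y = next(batches)
        x = batch_x.shape[1] // 2   # center along axis 1
        y = batch_x.shape[2] // 2   # center along axis 2
        size = new_size // 2        # half the crop side
        yield (batch_x[:, x-size:x+size, y-size:y+size, :], batch_y)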

x_train = HDF5Matrix(...)
y_train = HDF5Matrix(...)

datagen = ImageDataGenerator(rotation_range=180, ...)

model = create_model()

training_gen = crop_generator(datagen.flow(x_train, y_train, batch_size=128), new_size=64)

model.fit_generator(training_gen, ...)  # deprecated since TF 2.1; model.fit(training_gen, ...) accepts generators too

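With the NumPy slice batch_x[:, x-size:x+size, y-size:y+size, :] only the two spatial dimensions of the images are cropped; the batch and channel dimensions are left untouched (this assumes the default channels_last layout). Because the slice is applied to the whole batch at once, no for-loop over the individual images is needed.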
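To check the slicing without Keras installed, here is a minimal NumPy-only sketch. The helper name center_crop_batch and the dummy batch are illustrative assumptions, not part of the answer above. It computes the top-left corner of the window instead of the half size, so odd crop sizes come out with the requested shape.

```python
import numpy as np

def center_crop_batch(batch, new_size):
    """Crop the central new_size x new_size window from every image in a batch.

    Assumes a channels-last layout: (batch, height, width, channels).
    """
    height, width = batch.shape[1], batch.shape[2]
    if new_size > min(height, width):
        raise ValueError("new_size must not exceed the image height or width")
    # Top-left corner of the centered window; works for odd and even sizes.
    top = (height - new_size) // 2
    left = (width - new_size) // 2
    return batch[:, top:top + new_size, left:left + new_size, :]

# Dummy batch standing in for the output of datagen.flow(...).
batch = np.arange(2 * 6 * 6 * 3, dtype=np.float32).reshape(2, 6, 6, 3)
cropped = center_crop_batch(batch, 4)
print(cropped.shape)  # (2, 4, 4, 3)
```

As a geometric rule of thumb for choosing new_size: when a square image of side W is rotated by an arbitrary angle about its center, a centered square crop of side at most W/√2 (about 0.707 W) stays inside the rotated image, so it should contain no filled pixels. Translation shifts shrink that bound further.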

This may be helpful:

Extending Keras' ImageDataGenerator to Support Random Cropping

https://jkjung-avt.github.io/keras-image-cropping/

GitHub code:

https://github.com/jkjung-avt/keras-cats-dogs-tutorial/blob/master/train_cropped.py

train_datagen = ImageDataGenerator(......)
train_batches = train_datagen.flow_from_directory(DATASET_PATH + '/train',
                                              target_size=(256,256),
                                              ......)
train_crops = crop_generator(train_batches, 224)
net_final.fit_generator(train_crops, ......)
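The snippet above relies on a crop_generator defined in the linked article. As a rough idea of what a random-cropping wrapper can look like, here is a NumPy-only sketch. It is not the article's code: the names random_crop_batch and random_crop_generator, the rng parameter, and the dummy data are assumptions made for illustration, and channels-last batches are assumed.

```python
import numpy as np

def random_crop_batch(batch, crop_size, rng=None):
    """Take an independent random crop_size x crop_size crop from each image."""
    rng = np.random.default_rng() if rng is None else rng
    n, height, width, channels = batch.shape
    out = np.empty((n, crop_size, crop_size, channels), dtype=batch.dtype)
    for i in range(n):
        # The upper bound of rng.integers is exclusive, hence the + 1.
        top = rng.integers(0, height - crop_size + 1)
        left = rng.integers(0, width - crop_size + 1)
        out[i] = batch[i, top:top + crop_size, left:left + crop_size, :]
    return out

def random_crop_generator(batches, crop_size, rng=None):
    """Yield (cropped_x, y) for every (x, y) batch produced by batches."""
    for batch_x, batch_y in batches:
        yield random_crop_batch(batch_x, crop_size, rng), batch_y

# Dummy data standing in for the batches of flow_from_directory(...).
rng = np.random.default_rng(0)
dummy_batches = [(np.zeros((4, 8, 8, 3)), np.arange(4))]
crops, labels = next(random_crop_generator(iter(dummy_batches), 5, rng))
print(crops.shape)  # (4, 5, 5, 3)
```

Unlike the center crop, each image needs its own offset here, so the loop over the batch is hard to avoid without more elaborate indexing.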

I defined a simple function that wraps a generator by replacing its __next__ method. All you need to do is define a function that takes a batch of data x, y and returns a modified version. You can use it like this:

def my_wrapper(x, y):
    return transform(x), y # for example transform here does the transformation you want
wrap_generator(train_generator, my_wrapper)

This resembles the solution of @Aray Karjauv, but it preserves the generator object and its important attributes, such as generator.n, which are needed to show progress bars during training.

Here's the definition of wrap_generator :

def wrap_generator(gen, wrapper, restore_original=False):
    '''this decorator wraps the generator's __next__ with a given wrapper function.
    NOTE:
        calling this multiple times won't stack up multiple wrappers, instead, the very first/original
          __next__ method is stored and wrapped with a new function each time.
        to restore the original __next__ function, call wrap_generator(gen, None, restore_original=True)
    WARNING:
        this function only 
    Parameters
    ----------
        gen : iterator produced by tensorflow.keras.preprocessing.image.ImageDataGenerator
            a fully initialized batch iterator, e.g. the object returned by flow() or flow_from_directory()
        wrapper: function (x,y) -> (x,y)
            a simple function that takes one batch of data (like the ones the generator produces)
            and outputs a modified version of it
        restore_original: boolean 
            if true, the original __next__ will be restored
    Returns:
        the same generator object (for convenience)
    Example of usage:
        def my_wrapper(x, y):
            return transform(x), y
        wrap_generator(train_generator, my_wrapper)
    '''
    # store the original __next__ method in the generator (if not done before)
    if not hasattr(gen, '_original_next'):
        gen._original_next = gen.__next__
    # Restore original generator if demanded
    if restore_original:
        gen.__next__ = gen._original_next
        del gen._original_next
        return gen
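    # wrap the stored _original_next method with the wrapper function
    def fixed_next():
        x, y = gen._original_next()
        return wrapper(x, y)
    # CAVEAT: Python resolves special methods on the type, so the built-in next(gen)
    # ignores this instance attribute; only explicit gen.__next__() calls are wrapped.
    gen.__next__ = fixed_next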
    return gen
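One thing to watch with this approach: Python looks up special methods such as __next__ on the class, not on the instance, so the built-in next(gen) does not see a __next__ assigned to an instance. A possible alternative is a small wrapper object that applies the transformation in its own __next__ and forwards every other attribute (such as n) to the wrapped iterator. The sketch below is only an illustration: WrappedIterator and DummyBatches are hypothetical names, and DummyBatches merely stands in for a Keras iterator.

```python
class WrappedIterator:
    """Apply wrapper(x, y) to each batch and forward other attributes."""

    def __init__(self, gen, wrapper):
        self._gen = gen
        self._wrapper = wrapper

    def __iter__(self):
        return self

    def __next__(self):
        x, y = next(self._gen)
        return self._wrapper(x, y)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so attributes like n or
        # batch_size are forwarded to the wrapped iterator.
        if name in ("_gen", "_wrapper"):
            raise AttributeError(name)
        return getattr(self._gen, name)


class DummyBatches:
    """Hypothetical stand-in for a Keras iterator: endless (x, y) batches."""

    def __init__(self):
        self.n = 10

    def __iter__(self):
        return self

    def __next__(self):
        return [1, 2, 3], [0, 1, 0]


# Patching the instance does not change what the built-in next() calls.
gen = DummyBatches()
gen.__next__ = lambda: ("patched", None)
print(next(gen))  # ([1, 2, 3], [0, 1, 0])

# The wrapper class transforms batches and still exposes gen.n.
wrapped = WrappedIterator(DummyBatches(), lambda x, y: ([v * 2 for v in x], y))
print(next(wrapped))  # ([2, 4, 6], [0, 1, 0])
print(wrapped.n)      # 10
```

Note that __getattr__ does not forward special methods such as __len__ or __getitem__, so as far as I can tell Keras would treat an object like this as a plain Python generator rather than as a Sequence, and steps_per_epoch would then have to be passed explicitly.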


 