
How to do Keras image augmentation using a custom data generator?

I am using a custom Keras data generator, and I want to apply image augmentation to the data it returns.

These are the augmentation settings I want to apply:

ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

This is the data generation method of my custom Keras generator:

def __data_generation(self, list_IDs_temp):
    'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
    # Initialization
    X = np.empty((self.batch_size, *self.dim, self.n_channels))
    y = np.empty((self.batch_size), dtype=int)

    # Generate data
    for i, ID in enumerate(list_IDs_temp):
        # Store sample
        X[i,] = tfk.preprocessing.image.load_img(self.list_IDs[ID])

        # Store class
        y[i] = self.labels[ID]

    return X, tfk.utils.to_categorical(y, num_classes=self.n_classes)

I haven't tried it, but I think you can use the flow method of your ImageDataGenerator instance. For example, your custom class could look like this:

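import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

tfk = tf.keras  # assumed alias, matching the name used in the question's code

class CustomDataGenerator(tf.keras.utils.Sequence):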
    
    def __init__(self, batch_size=32):
        self.batch_size = batch_size
        self.augmentor = ImageDataGenerator(
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest'
        )

    ...

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            X[i,] = tfk.preprocessing.image.load_img(self.list_IDs[ID])

            # Store class
            y[i] = self.labels[ID]

        # Do not shuffle here: the custom class already shuffles beforehand.
        # You only want the transformations to be applied, and above all you
        # want to keep the images in sync with their labels.
        X_gen = self.augmentor.flow(X, batch_size=self.batch_size, shuffle=False)

        return next(X_gen), tfk.utils.to_categorical(y, num_classes=self.n_classes)
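
To try the idea end to end, here is a minimal self-contained sketch of the same pattern. It is not the asker's actual class: the name AugmentedSequence, its constructor arguments, and the use of in-memory arrays instead of image files are assumptions made only for illustration. It also assumes a TensorFlow version where ImageDataGenerator is still available, and SciPy installed, since Keras needs it for the random rotations, shifts, shears and zooms.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator


class AugmentedSequence(tf.keras.utils.Sequence):
    """Hypothetical in-memory variant of the custom generator above."""

    def __init__(self, images, labels, n_classes, batch_size=32, shuffle=True):
        super().__init__()
        self.images = images          # array of shape (n, height, width, channels)
        self.labels = labels          # integer class ids, shape (n,)
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Same augmentation settings as in the question.
        self.augmentor = ImageDataGenerator(
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest',
        )
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch (the last batch may be smaller).
        return int(np.ceil(len(self.images) / self.batch_size))

    def on_epoch_end(self):
        # Shuffling happens here, once per epoch, on the sample indexes.
        self.indexes = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        start = index * self.batch_size
        idx = self.indexes[start:start + self.batch_size]
        X = self.images[idx].astype("float32")
        y = self.labels[idx]
        # shuffle=False keeps the augmented images aligned with y.
        X_aug = next(self.augmentor.flow(X, batch_size=len(X), shuffle=False))
        return X_aug, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)


# Random stand-in data: 10 RGB images of 32x32 pixels, 3 classes.
images = np.random.rand(10, 32, 32, 3)
labels = np.random.randint(0, 3, size=10)

seq = AugmentedSequence(images, labels, n_classes=3, batch_size=4)
X_batch, y_batch = seq[0]
print(X_batch.shape, y_batch.shape)
```

An instance like seq can then be passed to model.fit in place of the original generator. Shuffling is kept in on_epoch_end so that flow only has to apply the transformations, which is the point made in the comment above.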
