
Keras Resnet50 preprocess_input gives error with Grayscale images when used with ImageDataGenerator

I am trying to use ImageDataGenerator with the ResNet50 architecture and have used:

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet import preprocess_input
ImageDataGenerator(preprocessing_function=preprocess_input)

The problem is that it does not seem to support grayscale images and only works with RGB images. I went to the source code and found, at the keras/applications/resnet link, what looked like the same preprocess_input for every architecture. Following the code flow on GitHub, I found preprocess_input defined as:

def preprocess_input(x):
    return x.astype('float32').reshape((-1,) + input_shape) / 255

It hardly does anything, and the same function appears to be used for VGG, MobileNet, ResNet and so on.

After more searching I found the source code of another preprocess_input, which works with exactly 3 channels.

I now have 3 questions:

  1. Why does every transfer learning architecture use the same preprocess_input?
  2. Which one is actually being used: the two-line code given at keras/applications/resnet or the 3-channel code given at tensorflow/python/keras...?
  3. What can I do to preprocess grayscale images?

Below is the error traceback:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-27-2591f0ea303d> in <module>()
      2                 epochs=1,
      3                 validation_data=val_data,
----> 4                 callbacks=callbacks)

11 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1061           use_multiprocessing=use_multiprocessing,
   1062           model=self,
-> 1063           steps_per_execution=self._steps_per_execution)
   1064 
   1065       # Container that configures and calls `tf.keras.Callback`s.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model, steps_per_execution)
   1115         use_multiprocessing=use_multiprocessing,
   1116         distribution_strategy=ds_context.get_strategy(),
-> 1117         model=model)
   1118 
   1119     strategy = ds_context.get_strategy()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, shuffle, workers, use_multiprocessing, max_queue_size, model, **kwargs)
    914         max_queue_size=max_queue_size,
    915         model=model,
--> 916         **kwargs)
    917 
    918   @staticmethod

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, workers, use_multiprocessing, max_queue_size, model, **kwargs)
    784     # Since we have to know the dtype of the python generator when we build the
    785     # dataset, we have to look at a batch to infer the structure.
--> 786     peek, x = self._peek_and_restore(x)
    787     peek = self._standardize_batch(peek)
    788     peek = _process_tensorlike(peek)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in _peek_and_restore(x)
    918   @staticmethod
    919   def _peek_and_restore(x):
--> 920     return x[0], x
    921 
    922   def _handle_multiprocessing(self, x, workers, use_multiprocessing,

/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in __getitem__(self, idx)
     63         index_array = self.index_array[self.batch_size * idx:
     64                                        self.batch_size * (idx + 1)]
---> 65         return self._get_batches_of_transformed_samples(index_array)
     66 
     67     def __len__(self):

/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in _get_batches_of_transformed_samples(self, index_array)
    237                 params = self.image_data_generator.get_random_transform(x.shape)
    238                 x = self.image_data_generator.apply_transform(x, params)
--> 239                 x = self.image_data_generator.standardize(x)
    240             batch_x[i] = x
    241         # optionally save augmented images to disk for debugging purposes

/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/image_data_generator.py in standardize(self, x)
    706         """
    707         if self.preprocessing_function:
--> 708             x = self.preprocessing_function(x)
    709         if self.rescale:
    710             x *= self.rescale

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/resnet.py in preprocess_input(x, data_format)
    522 def preprocess_input(x, data_format=None):
    523   return imagenet_utils.preprocess_input(
--> 524       x, data_format=data_format, mode='caffe')
    525 
    526 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/imagenet_utils.py in preprocess_input(x, data_format, mode)
    114   if isinstance(x, np.ndarray):
    115     return _preprocess_numpy_input(
--> 116         x, data_format=data_format, mode=mode)
    117   else:
    118     return _preprocess_symbolic_input(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/imagenet_utils.py in _preprocess_numpy_input(x, data_format, mode)
    231   else:
    232     x[..., 0] -= mean[0]
--> 233     x[..., 1] -= mean[1]
    234     x[..., 2] -= mean[2]
    235     if std is not None:

IndexError: index 1 is out of bounds for axis 2 with size 1
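
The last frames of the traceback can be reproduced without Keras. The sketch below is a simplified stand-in for the channels-last 'caffe' branch of `_preprocess_numpy_input`, not the library code itself. The name `caffe_style_preprocess` is hypothetical, and the mean values are assumed to be the ImageNet BGR means that Keras uses in this mode.

```python
import numpy as np

# ImageNet per-channel means in BGR order (assumed to match Keras 'caffe' mode).
BGR_MEAN = (103.939, 116.779, 123.68)


def caffe_style_preprocess(x):
    """Simplified stand-in for channels-last 'caffe' preprocessing."""
    x = x.astype("float32")   # astype copies, so the caller's array is untouched
    x = x[..., ::-1]          # RGB -> BGR
    x[..., 0] -= BGR_MEAN[0]
    x[..., 1] -= BGR_MEAN[1]  # requires a channel at index 1
    x[..., 2] -= BGR_MEAN[2]
    return x


# A grayscale image has a single channel, so indexing channel 1 fails.
gray = np.zeros((4, 4, 1), dtype="uint8")
try:
    caffe_style_preprocess(gray)
except IndexError as err:
    print("grayscale input fails:", err)

# Repeating the single channel three times gives the shape the function expects.
rgb_like = np.repeat(gray, 3, axis=-1)
out = caffe_style_preprocess(rgb_like)
print(out.shape)
```

So the failure is a shape problem: the function indexes channels 0, 1 and 2 unconditionally, and a grayscale array only has channel 0.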

I can only answer question number 3. This is what I came up with to make it work. I hope it helps. I also added some fully connected layers to use with transfer learning, and it was a classification problem with 3 classes, hence the last softmax layer.


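    import tensorflow as tf
    from tensorflow.keras import layers, Model

    resolution = 224  # assumed value; use your own image size

    # Pretrained backbone without the ImageNet classification head.
    resnet_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)

    # Single-channel input, copied three times to get the 3 channels ResNet50 expects.
    ipt = layers.Input(shape=(resolution, resolution, 1), name="input")
    x = tf.keras.layers.Concatenate()([ipt, ipt, ipt])
    x = tf.cast(x, tf.float32)
    x = tf.keras.applications.resnet50.preprocess_input(x)
    x = resnet_model(x)
    # Classification head for the 3 classes.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.35)(x)
    out = layers.Dense(3, activation='softmax')(x)

    full_model = Model(inputs=ipt, outputs=out)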

I took the idea from here and here.
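
If you would rather fix the data side instead of the model, one alternative is to create the generator without `preprocessing_function` and convert each batch after it leaves the generator. This is only a minimal sketch under assumptions: the batches are taken to be `(images, labels)` tuples with images shaped `(N, H, W, 1)`, the helper name `gray_to_three_channel_batches` is hypothetical, and `stand_in_preprocess` is a placeholder for the real ResNet `preprocess_input` so that the example runs with NumPy alone.

```python
import numpy as np


def gray_to_three_channel_batches(batches, preprocess):
    """Yield (images, labels) with grayscale images expanded to 3 channels.

    `batches` is any iterable of (images, labels) pairs; `preprocess` is a
    function that accepts a float32 batch shaped (N, H, W, 3).
    """
    for images, labels in batches:
        if images.shape[-1] != 1:
            raise ValueError("expected grayscale batches shaped (N, H, W, 1)")
        # Copy the single channel three times along the last axis.
        rgb_like = np.repeat(images, 3, axis=-1).astype("float32")
        yield preprocess(rgb_like), labels


def stand_in_preprocess(x):
    """Placeholder for the real preprocess_input (assumption for this demo)."""
    return x - 1.0


# One fake batch of two 8x8 grayscale images with class labels 0 and 2.
fake_batches = [(np.ones((2, 8, 8, 1), dtype="float32"), np.array([0, 2]))]
wrapped = list(gray_to_three_channel_batches(fake_batches, stand_in_preprocess))
images, labels = wrapped[0]
print(images.shape, labels)
```

In a real pipeline, `batches` would be the iterator returned by the generator and `preprocess` would be the ResNet `preprocess_input`. That iterator does not stop on its own, so you would probably need to pass `steps_per_epoch` to `fit`. The batches would then feed a model with a 3-channel input, not the model above, which already repeats the channel and preprocesses inside the graph. Loading the images with `color_mode='rgb'` in `flow_from_directory` should be another way to get 3 channels from grayscale files.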


 