I am trying to use ImageDataGenerator with the ResNet50 architecture, and have set up the preprocessing as:
from keras.applications.resnet import preprocess_input
ImageDataGenerator(preprocessing_function=preprocess_input)
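For context, here is a minimal sketch of the setup described above; the directory path, target size and batch size are placeholders for my actual values:

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet import preprocess_input

train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
train_data = train_datagen.flow_from_directory(
    'data/train',               # placeholder path
    target_size=(224, 224),
    color_mode='grayscale',     # the images are single-channel
    class_mode='categorical',
    batch_size=32)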
The problem is that it has no support for grayscale images; it only works on RGB images. I went through the source code and found in the keras/applications/resnet module that preprocess_input is the same for every architecture. Following the GitHub code I happened to find an implementation of preprocess_input defined as:
def preprocess_input(x):
    return x.astype('float32').reshape((-1,) + input_shape) / 255
It is not doing much, and the same function is used for VGG, MobileNet, ResNet and so on.
After searching I found the source code of another preprocess_input which works with 3 channels only.
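For reference, that 3-channel version (the mode='caffe' path in imagenet_utils) boils down to something like the rough sketch below: RGB is flipped to BGR and a per-channel ImageNet mean is subtracted, which is exactly where a single-channel array runs out of indices:

import numpy as np

def caffe_style_preprocess(x):
    # Rough sketch of imagenet_utils' mode='caffe' path for a channels-last numpy array
    x = x[..., ::-1].astype('float32')   # RGB -> BGR
    mean = [103.939, 116.779, 123.68]    # ImageNet BGR channel means
    x[..., 0] -= mean[0]
    x[..., 1] -= mean[1]                 # IndexError here when x has only 1 channel
    x[..., 2] -= mean[2]
    return x

gray_image = np.zeros((224, 224, 1), dtype='uint8')
# caffe_style_preprocess(gray_image)  # raises IndexError: index 1 is out of bounds for axis 2 with size 1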
I have 3 questions now about preprocess_input.
Below is the error traceback:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-27-2591f0ea303d> in <module>()
2 epochs=1,
3 validation_data=val_data,
----> 4 callbacks=callbacks)
11 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1061 use_multiprocessing=use_multiprocessing,
1062 model=self,
-> 1063 steps_per_execution=self._steps_per_execution)
1064
1065 # Container that configures and calls `tf.keras.Callback`s.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model, steps_per_execution)
1115 use_multiprocessing=use_multiprocessing,
1116 distribution_strategy=ds_context.get_strategy(),
-> 1117 model=model)
1118
1119 strategy = ds_context.get_strategy()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, shuffle, workers, use_multiprocessing, max_queue_size, model, **kwargs)
914 max_queue_size=max_queue_size,
915 model=model,
--> 916 **kwargs)
917
918 @staticmethod
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, workers, use_multiprocessing, max_queue_size, model, **kwargs)
784 # Since we have to know the dtype of the python generator when we build the
785 # dataset, we have to look at a batch to infer the structure.
--> 786 peek, x = self._peek_and_restore(x)
787 peek = self._standardize_batch(peek)
788 peek = _process_tensorlike(peek)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py in _peek_and_restore(x)
918 @staticmethod
919 def _peek_and_restore(x):
--> 920 return x[0], x
921
922 def _handle_multiprocessing(self, x, workers, use_multiprocessing,
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in __getitem__(self, idx)
63 index_array = self.index_array[self.batch_size * idx:
64 self.batch_size * (idx + 1)]
---> 65 return self._get_batches_of_transformed_samples(index_array)
66
67 def __len__(self):
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/iterator.py in _get_batches_of_transformed_samples(self, index_array)
237 params = self.image_data_generator.get_random_transform(x.shape)
238 x = self.image_data_generator.apply_transform(x, params)
--> 239 x = self.image_data_generator.standardize(x)
240 batch_x[i] = x
241 # optionally save augmented images to disk for debugging purposes
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/image_data_generator.py in standardize(self, x)
706 """
707 if self.preprocessing_function:
--> 708 x = self.preprocessing_function(x)
709 if self.rescale:
710 x *= self.rescale
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/resnet.py in preprocess_input(x, data_format)
522 def preprocess_input(x, data_format=None):
523 return imagenet_utils.preprocess_input(
--> 524 x, data_format=data_format, mode='caffe')
525
526
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/imagenet_utils.py in preprocess_input(x, data_format, mode)
114 if isinstance(x, np.ndarray):
115 return _preprocess_numpy_input(
--> 116 x, data_format=data_format, mode=mode)
117 else:
118 return _preprocess_symbolic_input(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/applications/imagenet_utils.py in _preprocess_numpy_input(x, data_format, mode)
231 else:
232 x[..., 0] -= mean[0]
--> 233 x[..., 1] -= mean[1]
234 x[..., 2] -= mean[2]
235 if std is not None:
IndexError: index 1 is out of bounds for axis 2 with size 1
I can only answer question number 3. This is what I came up with to make it work; I hope it helps. I also added some fully connected layers for transfer learning, and since it was a classification problem with 3 classes, the last layer is a softmax.
import tensorflow as tf
from tensorflow.keras import layers, Model

resolution = 224  # side length of the (square) grayscale inputs; adjust to your data

# Pretrained ResNet50 backbone without its classification head
resnet_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)

ipt = layers.Input(shape=(resolution, resolution, 1), name="input")
# Repeat the single grayscale channel three times to get an RGB-shaped tensor
x = tf.keras.layers.Concatenate()([ipt, ipt, ipt])
x = tf.cast(x, tf.float32)
# preprocess_input now runs inside the model, so the generator needs no preprocessing_function
x = tf.keras.applications.resnet50.preprocess_input(x)
x = resnet_model(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.35)(x)
out = layers.Dense(3, activation='softmax')(x)
full_model = Model(inputs=ipt, outputs=out)
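With preprocess_input folded into the model, the generator can feed raw grayscale batches directly and needs no preprocessing_function; a rough usage sketch (the path, resolution and training settings are placeholders):

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()
train_data = train_datagen.flow_from_directory(
    'data/train',                          # placeholder path
    target_size=(resolution, resolution),
    color_mode='grayscale',                # keep single-channel input; the model expands it
    class_mode='categorical')

full_model.compile(optimizer='adam',
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])
full_model.fit(train_data, epochs=1)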