
Keras predict_generator corrupted images

I am trying to run predictions on several million images with my trained model, using predict_generator in Python 3 with Keras and TensorFlow as the backend. The generator and the model predictions work, but some images in the directory are broken or corrupted, and these cause predict_generator to stop and throw an error. Once the offending image is removed, it works again until the next corrupted or broken image is fed through the function.

Since there are so many images, it is not feasible to run a script that opens every image and deletes the ones that throw an error. Is there a way to pass a "skip image if broken" argument to the generator or to the flow_from_directory function?
Any help is greatly appreciated!

There is no such argument in ImageDataGenerator or in its flow_from_directory method, as the Keras docs for both show. One workaround is to subclass ImageDataGenerator and override the flow_from_directory method so that it checks whether each image is corrupted before yielding it from the generator. The class's source code is the place to start.
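A minimal sketch of that idea, written independently of Keras internals: wrap the file list in a generator that tries to load each file and drops the ones that fail. The names `iter_valid_batches` and `load_fn` are hypothetical and not part of Keras. `load_fn` stands in for whatever decoder you use and is assumed to raise an exception on a broken file.

```python
def iter_valid_batches(paths, load_fn, batch_size):
    """Yield (kept_paths, items) batches, skipping files load_fn cannot read.

    load_fn is a hypothetical loader supplied by the caller; it is assumed
    to raise an exception when a file is broken or corrupted.
    """
    kept_paths, items = [], []
    for path in paths:
        try:
            item = load_fn(path)
        except Exception as exc:  # broad on purpose: decoders raise many error types
            print('skipping {}: {}'.format(path, exc))
            continue
        kept_paths.append(path)
        items.append(item)
        if len(items) == batch_size:
            yield kept_paths, items
            kept_paths, items = [], []
    if items:  # final, possibly smaller batch
        yield kept_paths, items
```

Each batch carries the paths that survived, so predictions can still be matched to file names afterwards. Since predict_generator expects the generator to yield only the input arrays, the kept paths would have to be recorded on the side rather than yielded.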

Since it happens during prediction, if you skip any image or batch, you need to keep track of which images are skipped, so that you can correctly map the prediction scores to the image file name.

Based on this idea, my DataGenerator keeps track of the indices of valid images. In particular, look at the variable valid_index, which records, for each batch, the index of every image that loaded successfully.

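import numpy as np
import keras
# assumption: resnet50_preprocess is Keras' ResNet50 preprocess_input (the import is not shown in the answer)
from keras.applications.resnet50 import preprocess_input as resnet50_preprocess


class DataGenerator(keras.utils.Sequence):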
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        # dict supplied by the caller; filled with {batch number: row labels of the valid images}
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ',idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]

        # load_batch_image_to_arrays (helper not shown) returns a list with one entry per file:
        # an image array when the image is valid, or None when it is corrupted
        x = load_batch_image_to_arrays(batch_df['image_file_names'])

        # filter out corrupted images
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist()) if
               u is not None]

        # boundary case: every image in this batch failed to load
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # based on https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621,
            # Keras will automatically find the next batch if it returns None
            return None

        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)

        x, batch_index = zip(*tmp) 
        x = np.stack(x)  # list to np.array
        self.valid_index[idx] = batch_index

        # follow preprocess input function provided by keras
        # np.float was an alias for the builtin float and was removed in NumPy 1.24
        x = resnet50_preprocess(np.array(x, dtype=np.float64))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0 # reset count after one epoch ends.
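The helper load_batch_image_to_arrays is not shown in the answer. One possible sketch, assuming Pillow and NumPy are installed and that every image should be resized to 224x224 RGB (the `target_size` default is an assumption, chosen because it is the usual ResNet50 input size):

```python
import numpy as np
from PIL import Image


def load_batch_image_to_arrays(file_names, target_size=(224, 224)):
    """Return one entry per file: an HxWx3 float32 array, or None if unreadable."""
    arrays = []
    for name in file_names:
        try:
            with Image.open(name) as img:
                # convert() forces a full decode, so truncated files fail here
                img = img.convert('RGB').resize(target_size)
                arrays.append(np.asarray(img, dtype=np.float32))
        except Exception as exc:  # missing file, unknown format, truncated data, ...
            print('[WARN] could not load {}: {}'.format(name, exc))
            arrays.append(None)
    return arrays
```

Returning None in place of a failed image, rather than dropping it, keeps the output list aligned with the input file names, which is what lets the generator zip it against batch_df.index.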

During prediction, flatten valid_index in batch order to map each score back to its row:

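predictions = model.predict_generator(
    generator=data_gen,
    workers=10,
    use_multiprocessing=False,
    max_queue_size=20,
    verbose=1
).squeeze()

# flatten the per-batch indices in batch order so they line up with predictions
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])

# .copy() avoids pandas' SettingWithCopyWarning when the score column is added
result_df = df.loc[indexes].copy()
result_df['score'] = predictions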
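To see why the index bookkeeping matters, here is a small worked example with made-up data. The file names, scores, and valid_index contents are all invented for illustration; only the flattening logic mirrors the snippet above.

```python
import numpy as np
import pandas as pd

# Five files, batch size 2 -> batches 0, 1, 2. Rows 1 and 4 are "corrupted".
df = pd.DataFrame({'image_file_names': ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg']})

# What the generator would have recorded: batch number -> surviving row labels.
# Batch 2 held only row 4, which failed, so it recorded nothing.
valid_index = {1: (2, 3), 0: (0,)}

# One score per surviving image, in batch order (invented values).
predictions = np.array([0.9, 0.2, 0.7])

indexes = []
for batch in sorted(valid_index):
    indexes.extend(valid_index[batch])

assert len(indexes) == len(predictions)  # guard against misaligned output
result_df = df.loc[indexes].copy()
result_df['score'] = predictions
print(result_df)
```

Sorting the batch numbers matters because the dict is filled in whatever order the worker threads finish, while the predictions come back in batch order.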
