簡體   English   中英

Keras predict_generator 損壞的圖像

[英]Keras predict_generator corrupted images

我正在嘗試使用 python 3 中的 predict_generator 以 keras 和 tensorflow 作為后端,用我的訓練模型預測幾百萬張圖像。 生成器和模型預測有效,但是,目錄中的某些圖像已損壞或損壞並導致 predict_generator 停止並拋出錯誤。 刪除圖像后,它會再次工作,直到下一個損壞/損壞的圖像通過該函數輸入。

由於圖像太多,運行腳本來打開每個圖像並刪除引發錯誤的圖像是不可行的。 有沒有辦法將“如果損壞則跳過圖像”參數合並到生成器或目錄函數中?
任何幫助是極大的贊賞!

ImageDataGeneratorflow_from_directory方法中都沒有這樣的參數, ImageDataGenerator您可以在flow_from_directory文檔中看到兩者( 此處此處)。 一種解決方法是擴展ImageDataGenerator類並重載flow_from_directory方法以在生成器中生成圖像之前檢查圖像是否已損壞。 在這里你可以找到它的源代碼。

由於它發生在預測過程中,因此如果您跳過任何圖像或批處理,則需要跟蹤跳過哪些圖像,以便您可以將預測分數正確映射到圖像文件名。

基於這個想法,我的 DataGenerator 實現了一個有效的圖像索引跟蹤器。 特別是,關注變量valid_index ,其中跟蹤有效圖像的索引。

class DataGenerator(keras.utils.Sequence):
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ',idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]

        # return a list whose element is either an image array (when image is valid) or None(when image is corrupted)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])

        # filter out corrupted images
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist()) if
               u is not None]

        # boundary case. # all image failed, return another random batch
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # based on https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621,
            # Keras will automatically find the next batch if it returns None
            return None

        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)

        x, batch_index = zip(*tmp) 
        x = np.stack(x)  # list to np.array
        self.valid_index[idx] = batch_index

        # follow preprocess input function provided by keras
        x = resnet50_preprocess(np.array(x, dtype=np.float))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0 # reset count after one epoch ends.

在預測過程中。

predictions = model.predict_generator(
            generator=data_gen,
            workers=10,
            use_multiprocessing=False,
            max_queue_size=20,
            verbose=1
        ).squeeze()
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])
result_df = df.loc[indexes]
result_df['score'] = predictions

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM