Keras predict_generator corrupted images
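I am trying to predict several million images with my trained model using predict_generator in Python 3, with Keras and TensorFlow as the backend. The generator and the model predictions work, but some images in the directory are broken or corrupted and cause predict_generator to stop and throw an error. Once the offending image is removed, it works again until the next corrupted image is fed through the function.
Because there are so many images, it is not feasible to run a script that opens every image and deletes the ones that raise an error. Is there a way to incorporate a "skip image if corrupted" argument into the generator or the directory function?
Any help is greatly appreciated!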
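Since this happens during prediction, if you skip any image or batch you need to keep track of which images were skipped, so that you can correctly map the prediction scores back to the image file names.
Based on this idea, my DataGenerator implements a valid image index tracker. In particular, pay attention to the variable valid_index, which records the indexes of the images that loaded successfully.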
import numpy as np
import keras
from keras.applications.resnet50 import preprocess_input as resnet50_preprocess


class DataGenerator(keras.utils.Sequence):
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        # dict shared with the caller: batch number -> row labels of the images that loaded
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ', idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]
        # load_batch_image_to_arrays is a user-supplied helper that returns a list whose
        # elements are either an image array (valid image) or None (corrupted image)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])
        # filter out corrupted images, keeping each array paired with its row label
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist())
               if u is not None]
        # boundary case: every image in this batch failed to load
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # based on https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621,
            # Keras will automatically find the next batch if it returns None
            return None
        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)
        x, batch_index = zip(*tmp)
        x = np.stack(x)  # list of arrays -> single np.array
        self.valid_index[idx] = batch_index
        # follow the preprocess_input function provided by keras
        # (np.float was removed from NumPy; np.float64 is the same dtype)
        x = resnet50_preprocess(np.array(x, dtype=np.float64))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0  # reset counts after one epoch ends
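The class above calls load_batch_image_to_arrays, which the answer does not define. Here is a minimal sketch of what such a helper could look like, assuming Pillow is used for decoding. The names pil_loader and the loader parameter, as well as the 224x224 target size, are assumptions made for illustration and are not part of the original answer.

```python
def pil_loader(path, target_size=(224, 224)):
    """Decode one image into a float32 array; raises if the file is unreadable."""
    # Imported lazily so the batch helper can also be used with another loader.
    import numpy as np
    from PIL import Image

    with Image.open(path) as img:
        # convert() forces a full decode, so truncated files fail here
        # instead of later inside the model.
        img = img.convert("RGB").resize(target_size)
        return np.asarray(img, dtype=np.float32)


def load_batch_image_to_arrays(paths, loader=pil_loader):
    """Return one entry per path: the decoded image, or None if loading failed."""
    arrays = []
    for path in paths:
        try:
            arrays.append(loader(path))
        except Exception as exc:  # broad on purpose: decoders raise many error types
            print("[WARN] skipping unreadable image {}: {}".format(path, exc))
            arrays.append(None)
    return arrays
```

Because the output has the same length and order as the input, the generator can zip it against batch_df.index and drop the None entries without losing track of which row each array belongs to.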
During prediction:
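# predict_generator is deprecated in newer Keras releases;
# model.predict accepts a Sequence directly.
predictions = model.predict_generator(
    generator=data_gen,
    workers=10,
    use_multiprocessing=False,
    max_queue_size=20,
    verbose=1
).squeeze()

# rebuild the row labels of the valid images in batch order
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])

# .copy() so that adding the score column does not write into a view of df
result_df = df.loc[indexes].copy()
result_df['score'] = predictions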
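To make the mapping step concrete, here is a small standalone sketch using plain Python containers in place of the DataFrame. The batch contents and scores are made up for illustration, and collect_valid_indexes is a hypothetical helper name, not something from Keras or the answer above.

```python
def collect_valid_indexes(valid_index):
    """Flatten {batch_number: row_labels} into one list, in batch order."""
    indexes = []
    for batch_number in sorted(valid_index):
        indexes.extend(valid_index[batch_number])
    return indexes


# Hypothetical run with batch_size=3: batch 1 lost row 4, and batch 2
# failed entirely, so it never wrote an entry. Worker threads may fill
# the dict out of order, which is why the keys are sorted.
valid_index = {1: (3, 5), 0: (0, 1, 2), 3: (9, 10, 11)}
indexes = collect_valid_indexes(valid_index)

predictions = [0.9, 0.1, 0.4, 0.7, 0.2, 0.8, 0.3, 0.6]  # made-up scores

# A length mismatch means the scores can no longer be trusted to line up.
if len(predictions) != len(indexes):
    raise ValueError("predictions and valid indexes are out of sync")

score_by_row = dict(zip(indexes, predictions))
print(score_by_row)
```

Checking the lengths before assigning scores is a cheap guard: if some batch was skipped or repeated in a way the tracker did not record, the mismatch shows up here rather than as silently misaligned scores.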