簡體   English   中英

如何在keras中使用flow_from_dataframe輸入多個圖像?

[英]How to input multiple images with flow_from_dataframe in keras?

我一直在嘗試創建Siamese 模型來查找 2 個圖像之間的圖像相似性(它有 2 個輸入圖像)。 一開始我用一個小數據集對其進行了測試,它安裝在我的 RAM 中並且運行良好。 現在,我想增加訓練樣本大小,為此我創建了images.csv文件。 在這個文件中,我有 3 列: image_1image_2similarity

image_1image_2是圖像的絕對路徑。 similarity為 0 或 1。

我試過

generator.flow_from_dataframe(dataframe, target_size=(64, 64, 1), x_col=['image_1', 'image_2'],
                                                    y_col='similarity',
                                                    class_mode='sparse', subset='training')

但得到這個錯誤:

ValueError: x_col=['image_1', 'image_2'] 列中的所有值都必須是字符串。

刪除image_2並且x_col=image_1錯誤消失后,它只有 1 個輸入圖像。

我該怎么辦?

您不能使用該方法從單個生成器中傳輸兩個圖像,它旨在處理來自文檔的一個圖像

x_col:字符串,數據框中包含文件名的列(如果目錄為 None,則為絕對路徑)。

相反,您可以做的是創建兩個生成器,更合適地允許您的網絡有兩個輸入

in1 = generator.flow_from_dataframe(dataframe, target_size=(64, 64, 1), x_col='image_1',
                                                    y_col='similarity',
                                                    class_mode='sparse', subset='training')

in2 = generator.flow_from_dataframe(dataframe, target_size=(64, 64, 1), x_col='image_2',
                                                    y_col='similarity',
                                                    class_mode='sparse', subset='training')

然后使用接受兩個圖像輸入的函數式 API構建模型:

input_image1 = Input(shape=(64, 64, 1))
input_image2 = Input(shape=(64, 64, 1))
# ... all other layers to create output_layer
model = Model([input_image1, input_image2], output)
# ...

這更能反映您的模型實際上有 2 個輸入作為圖像。

在@nuric 的幫助下,我能夠輸入多個圖像。 這是創建流的完整代碼:

def get_flow_from_dataframe(generator, dataframe,
                            image_shape=(64, 64),
                            subset='training',
                            color_mode='grayscale', batch_size=64):
    train_generator_1 = generator.flow_from_dataframe(dataframe, target_size=image_shape,
                                                      color_mode=color_mode,
                                                      x_col='image_1',
                                                      y_col='prediction',
                                                      class_mode='binary',
                                                      shuffle=True,
                                                      batch_size=batch_size,
                                                      seed=7,
                                                      subset=subset, drop_duplicates=False)

    train_generator_2 = generator.flow_from_dataframe(dataframe, target_size=image_shape,
                                                      color_mode=color_mode,
                                                      x_col='image_2',
                                                      y_col='prediction',
                                                      class_mode='binary',
                                                      shuffle=True,
                                                      batch_size=batch_size,
                                                      seed=7,
                                                      subset=subset, drop_duplicates=False)
    while True:
        x_1 = train_generator_1.next()
        x_2 = train_generator_2.next()

        yield [x_1[0], x_2[0]], x_1[1]

fit_generator 的完整代碼:

train_gen = get_flow_from_dataframe(generator, dataframe, image_shape=(64, 64),
                                        color_mode='rgb',
                                        batch_size=batch_size)
valid_gen = get_flow_from_dataframe(generator, dataframe, image_shape=(64, 64),
                                        color_mode='rgb',
                                        batch_size=batch_size,
                                        subset='validation')

model.fit_generator(train_gen, epochs=50,
                        steps_per_epoch=step_size,
                        validation_data=valid_gen,
                        validation_steps=step_size,
                        callbacks=get_call_backs('../models/model_1.h5', monitor='val_acc'),
                        )

此外,正如我所見,內存消耗是巨大的。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM