簡體   English   中英

如何使用 keras 和 tensorflow 的 ImageDataGenerator 執行數據增強

[英]How to perform data augmentation using keras and tensorflow's ImageDataGenerator

我很難理解如何使用 tensorflow 實現數據增強。 我有一個數據集(圖像),分為兩個子集; 培訓和測試。 在我使用各種參數調用ImageDataGenerator函數后,我是否需要保存圖像(如使用flow() )或者 Tensorflow 是否會在 model 訓練時增加我的數據?

這是我實現的代碼:

# necessary imports

train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    brightness_range=(0.3, 1.0),
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
    validation_split=0.2
)

training_directory = '/tmp/dataset/training'
testing_directory = '/tmp/dataset/testing'

training_set = train_datagen.flow_from_directory(
    training_directory,
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary',
    subset='training'
)

test_set = train_datagen.flow_from_directory(
    testing_directory,
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary',
    subset='validation'
)

# creating a sequential model
...
# fitting and data plotting

model總結:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 148, 148, 32)      896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 74, 74, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 72, 72, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 36, 36, 64)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 34, 34, 128)       73856
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 17, 17, 128)       0
_________________________________________________________________
dropout (Dropout)            (None, 17, 17, 128)       0
_________________________________________________________________
flatten (Flatten)            (None, 36992)             0
_________________________________________________________________
dense (Dense)                (None, 512)               18940416
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 513
=================================================================
Total params: 19,034,177
Trainable params: 19,034,177
Non-trainable params: 0
_________________________________________________________________

您不必保存新數據。

調用流方法時,數據會即時擴充並作為 model 的輸入。

因此,數據正在實時生成並立即輸入您的 model。

您無需保存數據。 增強數據(訓練/測試)直接輸入 model 以使用訓練和測試數據生成器進行訓練或評估步驟。

這是使用創建的數據生成器train_generatortest_generator更新了所有步驟的代碼。

 datagenerator = ImageDataGenerator(
    rescale=1. / 255,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    brightness_range=(0.3, 1.0),
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
    validation_split=0.2
)
 
training_directory = '/tmp/dataset/training'
testing_directory = '/tmp/dataset/testing'

train_generator = datagenerator.flow_from_directory(
    training_directory,
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary',
    subset='training'
)

test_generator = datagenerator.flow_from_directory(
    testing_directory,
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary',
    subset='validation'
)

# Build and compile the model
....
# Get the number of steps per epoch for each of the data generators
train_steps_per_epoch = train_generator.n // train_generator.batch_size
test_steps_per_epoch = test_generator.n // test_generator.batch_size

# Fit the model
model.fit_generator(train_generator, steps_per_epoch=train_steps_per_epoch, epochs=your_nepochs)

# Evaluate the model
model.evaluate_generator(test_generator, steps=test_steps_per_epoch)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM