
Fitting a Keras model with a large amount of data using a custom data generator

I'm trying to fit my Keras model on quite a large amount of data.

To do this, I'm using custom data generators and the model.fit_generator function.

However, I can't tell whether I'm doing this correctly.

Here's what I have:

from os.path import join

import cv2
import numpy as np
from keras.models import Sequential
from keras.layers.core import Flatten, Dense, Dropout
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau

# The function returns a list of image names from folder
from data.preprocessing import get_list_of_images


class VGG19(object):
    def __init__(self, weights_path=None, train_folder='data/train', validation_folder='data/val'):
        self.weights_path = weights_path
        self.model = self._init_model()

        if weights_path:
            self.model.load_weights(weights_path)
        else:
            self.datagen = self._init_datagen()
            self.train_folder = train_folder
            self.validation_folder = validation_folder
            self.model.compile(
                loss='binary_crossentropy',
                optimizer='adam',
                metrics=['accuracy']
            )

    def fit(self, batch_size=32, nb_epoch=10):
        self.model.fit_generator(
            self._generate_data_from_folder(self.train_folder), 32,
            nb_epoch,
            verbose=1,
            callbacks=[
                TensorBoard(log_dir='./logs', write_images=True),
                ModelCheckpoint(filepath='weights.{epoch:02d}-{val_loss:.2f}.hdf5', monitor='val_loss'),
                ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.001)
            ],
            validation_data=self._generate_data_from_folder(self.validation_folder),
            nb_val_samples=32
        )

    def predict(self, X, batch_size=32, verbose=1):
        return self.model.predict(X, batch_size=batch_size, verbose=verbose)

    def predict_proba(self, X, batch_size=32, verbose=1):
        return self.model.predict_proba(X, batch_size=batch_size, verbose=verbose)

    def _init_model(self):
        model = Sequential()
        # model definition goes here...
        return model

    def _init_datagen(self):
        return ImageDataGenerator(
            featurewise_center=True,
            samplewise_center=False,
            featurewise_std_normalization=True,
            samplewise_std_normalization=False,
            zca_whitening=False,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            vertical_flip=True
        )

    def _generate_data_from_folder(self, folder_path):
        while 1:
            images = get_list_of_images(folder_path)

            for image_path in images:
                x = cv2.imread(join(folder_path, image_path))
                y = 0 if image_path.split('.')[0] == 'dog' else 1

                yield (x, y)

My dataset consists of images with names like:

  • cat.[number].jpg, e.g. cat.124.jpg

  • dog.[number].jpg, e.g. dog.64.jpg

So, basically, I'm trying to train a model to perform binary cat-vs-dog classification.


Is my _generate_data_from_folder function correctly implemented for mini-batch optimization?
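
One possible answer to this question: as written, the generator yields a single (x, y) pair at a time with no batch axis, whereas fit_generator expects every yielded item to be a whole batch, i.e. a tuple of arrays whose first dimension is the batch size. Below is a minimal sketch of a batched version. It assumes every image is brought to a common shape by the loader. The names batch_generator, load_image and label_from_name are hypothetical and do not appear in the original code; the loader is passed in as an argument so the sketch does not depend on OpenCV.

```python
import random
from os.path import join

import numpy as np


def label_from_name(image_name):
    """Same rule as in the question: 'dog.*' -> 0, anything else -> 1."""
    return 0 if image_name.split('.')[0] == 'dog' else 1


def batch_generator(folder_path, image_names, load_image,
                    batch_size=32, shuffle=True):
    """Yield (X, y) batches forever.

    load_image is a hypothetical callable that maps a file path to an
    array of fixed shape, for example:
        lambda p: cv2.resize(cv2.imread(p), (224, 224))
    """
    names = list(image_names)
    while True:
        if shuffle:
            # Reshuffle at the start of every pass over the data.
            random.shuffle(names)
        for start in range(0, len(names), batch_size):
            chunk = names[start:start + batch_size]
            # Stack individual images into one array with a batch axis.
            X = np.stack([load_image(join(folder_path, n)) for n in chunk])
            y = np.array([label_from_name(n) for n in chunk], dtype=np.float32)
            # The last batch of a pass may be smaller than batch_size.
            yield X.astype(np.float32), y
```

With a generator like this, the second positional argument of fit_generator (samples_per_epoch in the Keras 1.x API used above) would normally be the number of training images rather than 32, and nb_val_samples the number of validation images.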

How can I use the ImageDataGenerator (created in the _init_datagen function) inside my _generate_data_from_folder function?
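
For this second question, one option is to apply the ImageDataGenerator image by image inside a custom generator, through its random_transform and standardize methods. The sketch below is only an assumption about how this could be wired up, not necessarily what the final project does. augment_batches is a hypothetical helper, and datagen can be any object exposing those two methods, such as the one returned by _init_datagen.

```python
import numpy as np


def augment_batches(batches, datagen):
    """Wrap a generator of (X, y) batches, augmenting every image in X.

    datagen is assumed to expose random_transform(x) and standardize(x),
    as keras.preprocessing.image.ImageDataGenerator does.
    """
    for X, y in batches:
        X_aug = np.stack([
            # Random augmentation first, then normalization, per image.
            datagen.standardize(datagen.random_transform(x.astype(np.float32)))
            for x in X
        ])
        yield X_aug, y
```

Because featurewise_center and featurewise_std_normalization are enabled in _init_datagen, datagen.fit(...) would need to be called on a representative sample of training images beforehand so that the mean and standard deviation are available. If the images were sorted into one subfolder per class, datagen.flow_from_directory would be a built-in alternative to a hand-written generator.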

Okay, here's the GitHub link to the final version of the project that I got working:

https://github.com/yakovenkodenis/dogs-vs-cats-kaggle

Hope it helps somebody.
