繁体   English   中英

如何用状态编写Keras fit_generator的生成器?

[英]how to write a generator for Keras fit_generator with a state?

我正在尝试将大型数据集提供给keras模型。 数据集不适合内存。 它目前存储为一系列hd5f文件

我想用我的模型训练

model.fit_generator(my_gen, steps_per_epoch=30, epochs=10, verbose=1)

但是,在我可以在网上找到的所有示例中, my_gen仅用于对已加载的数据集执行数据扩充。 例如

def generator(features, labels, batch_size):

 # Create empty arrays to contain batch of features and labels#

 batch_features = np.zeros((batch_size, 64, 64, 3))
 batch_labels = np.zeros((batch_size,1))

 while True:
   for i in range(batch_size):
     # choose random index in features
     index= random.choice(len(features),1)
     batch_features[i] = some_processing(features[index])
     batch_labels[i] = labels[index]
   yield batch_features, batch_labels

就我而言,它需要像

def generator(features, labels, batch_size):    
 while True:
   for i in range(batch_size):
     # choose random index in features
     index= # SELECT THE NEXT FILE
     batch_features[i] = some_processing(features[files[index]])
     batch_labels[i] = labels[file[index]]
   yield batch_features, batch_labels

如何跟踪上一批中已读取的文件?

来自keras doc

generator:生成器或Sequence(keras.utils.Sequence)对象的实例,以避免在使用多处理时出现重复数据。 [...]

这意味着您可以编写一个继承自keras.utils.sequence的类

class ProductSequence(keras.utils.Sequence):
    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

__init__是初始化类。 __len__应返回每个纪元的批次数。 Keras将使用它来知道哪个索引可以传递给__getitem__ 然后__getitem__将根据索引返回批处理数据。 这里有一个简单的例子

使用这种方法,您可以简单地拥有一个内部类对象,您可以在其中保存已读取的文件。

我们假设您的数据是图像。 如果您有许多图像,您可能无法将所有图像加载到内存中,并且您希望分批从磁盘读取。

Keras flow_from _directory非常快,这样做,因为它这个在一个多线程的方式太多,但它需要所有的图像是在不同的文件,根据他们的阶级。 如果我们将所有图像放在同一个文件中,并将它们的类放在分离的文件中,我们可以使用生成器波纹管来加载我们的x,y数据。

import pandas as pd
import numpy as np
import cv2

#df_train:  data frame with class of every image
#dpath: path of images

classes=list(np.unique(df_train.label)) 
def batch_generator(ids):
    while True:
        for start in range(0, len(ids), batch_size):
            x_batch = []
            y_batch = []
            end = min(start + batch_size, len(ids))
            ids_batch = ids[start:end]
            for id in ids_batch:
                img = cv2.imread(dpath+'train/{}.png'.format(id)) #open cv read as BGR
                #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #BGR to RGB
                #img = cv2.resize(img, (224, 224), interpolation = cv2.INTER_CUBIC)
                #img = pre_process(img)
                labelname=df_train.label.loc[df_train.id==id].values
                labelnum=classes.index(labelname)
                x_batch.append(img)
                y_batch.append(labelnum)
            x_batch = np.array(x_batch)
            y_batch = to_categorical(y_batch,10) 
            yield x_batch, y_batch

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM