
Preloading part of HDF5 file while performing other task

I'm training a deep learning classifier, which uses an HDF5 dataset that is too large to fit into memory. Therefore, I extract the data in batches of 256 and use these batches to train my classifier in the following way. The deep learning library that I use (Keras) provides the method model.train_on_batch(X_batch, y_batch).

for i in range(n_batches_in_dset):
    X_batch, y_batch = load_partition('train', ind=[i*batch_size, (i+1)*batch_size])
    loss = model.train_on_batch(X_batch, y_batch)

It would make sense to prefetch the next batch of data while training on the current batch on the GPU. How can I do this in Python?

I've attached the code that I use for loading the data.

import h5py
import numpy as np
from keras.utils import np_utils

def load_hdf5(path, datapart, ind=None):
    # read either the whole dataset or the slice ind[0]:ind[1] from the HDF5 file
    f = h5py.File(path, 'r')
    if ind is None:
        dat = f[datapart][:]
    else:
        dat = f[datapart][ind[0]:ind[1]]
    f.close()
    return np.array(dat)

def load_partition(name, ind=None):
    path = DEEP_ROOT + 'data/{}.h5'.format(name)
    X = load_hdf5(path, 'data', ind)
    y = load_hdf5(path, 'label', ind)
    X = np.swapaxes(X, 2, 3)        # put the data in the layout the model expects
    y = np_utils.to_categorical(y)  # one-hot encode the labels
    return X, y

Probably the simplest thing to do is to put the separate tasks in separate threads, with a synchronized queue to hand the batches between them. We'll use a separate thread for the data-reading part, and the main thread for the training portion.

import Queue, threading  # on Python 3 the module is named queue

data_queue = Queue.Queue(2)  # a queue with space for two "chunks"
sentinel = object()

# start the data-loading task
def load_task():
    for i in range(n_batches_in_dset):
        data_queue.put(load_partition('train', ind=[i*batch_size, (i+1)*batch_size]), True)
    # tell the other side we're "done"
    data_queue.put(sentinel, True)

threading.Thread(target=load_task).start()

while True:
    batch = data_queue.get(True)
    data_queue.task_done()
    if batch is sentinel:
        break # we're done now!
    X_batch, y_batch = batch
    loss = model.train_on_batch(X_batch, y_batch)

EDIT: call Queue.task_done() after each get() if you also want to be able to Queue.join() on the queue; with a bounded queue, the get() itself already frees a slot for the producer's next put().
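
If you are on Python 3, the standard-library module is called queue rather than Queue. Here is a minimal sketch of the same producer/consumer pattern in Python 3, reusing load_partition, model, batch_size and n_batches_in_dset from the question above (those names are assumptions of this sketch):

import queue
import threading

data_queue = queue.Queue(maxsize=2)  # at most two prefetched batches held in memory
sentinel = object()                  # marks the end of the stream

def load_task():
    # producer: read batches from disk and hand them to the training loop
    for i in range(n_batches_in_dset):
        batch = load_partition('train', ind=[i*batch_size, (i+1)*batch_size])
        data_queue.put(batch)        # blocks while the queue is full
    data_queue.put(sentinel)         # tell the consumer we're done

loader = threading.Thread(target=load_task)
loader.start()

# consumer: train on the current batch while the loader reads the next one
while True:
    batch = data_queue.get()
    if batch is sentinel:
        break
    X_batch, y_batch = batch
    loss = model.train_on_batch(X_batch, y_batch)

loader.join()

Because the queue is bounded, put() blocks as soon as two batches are waiting, so the loader never reads more than a couple of chunks ahead of the GPU.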
