How to use tf.data.Dataset.from_generator() to load only one batch at a time from the dataset?

I want to train a CNN, and I am trying to feed the model one batch at a time directly from a NumPy memmap, without loading the whole dataset into memory, using tf.data.Dataset.from_generator(). I am using TF 2.2 and fitting on the GPU. The dataset is a sequence of 3D matrices (NCHW format), and the label of each case is the next 3D matrix. The problem is that it still loads the whole dataset into memory.
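For scale, here is the back-of-the-envelope arithmetic behind "the whole dataset", using the shape and float64 dtype from the example below. The batch figure assumes batch_size = 32 and counts one input case plus one label case per sample.

```python
import numpy as np

# Shape and dtype taken from the reproducible example: (cases, C, H, W), float64
ds_shape = (20000, 3, 50, 50)
itemsize = np.dtype("float64").itemsize                 # 8 bytes per value
bytes_per_case = int(np.prod(ds_shape[1:])) * itemsize  # 3 * 50 * 50 * 8
total_bytes = ds_shape[0] * bytes_per_case              # the whole file

# Each sample is (case i, case i + 1), so N cases give N - 1 samples
n_samples = ds_shape[0] - 1

# One batch holds batch_size inputs plus batch_size labels
batch_size = 32
batch_bytes = 2 * batch_size * bytes_per_case

print(bytes_per_case, total_bytes, n_samples, batch_bytes)
# 60000 1200000000 19999 3840000
```

So the file is about 1.2 GB, while one batch of inputs plus labels is under 4 MB.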

Here is a short reproducible example:

import numpy as np
from numpy.lib.format import open_memmap
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

tf.config.list_physical_devices("GPU")


# create and initialize the memmap
ds_shape = (20000, 3, 50, 50)
ds_mmap = open_memmap("ds.npy",
                      mode='w+',
                      dtype=np.dtype("float64"),
                      shape=ds_shape)
ds_mmap[:] = np.random.rand(*ds_shape)  # fill the memmap in place (rebinding the name would discard it)
ds_mmap.flush()

len_ds = len(ds_mmap)          # 20000
len_train = int(0.6 * len_ds)  # 12000
len_val = int(0.2 * len_ds)    # 4000
len_test = int(0.2 * len_ds)   # 4000
batch_size = 32
epochs = 50
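Whether ds_mmap is still a memmap after initialization matters here: assigning the result of np.random.rand() to the name replaces the memmap object with an ordinary in-memory ndarray, while slice assignment (ds_mmap[:] = ...) writes into the file-backed buffer. A small sketch, assuming a reduced shape and a temporary file instead of the question's ds.npy:

```python
import os
import tempfile

import numpy as np
from numpy.lib.format import open_memmap

# Reduced shape and a temporary path, just for this sketch
shape = (10, 3, 4, 4)
path = os.path.join(tempfile.mkdtemp(), "ds_small.npy")

mm = open_memmap(path, mode="w+", dtype=np.float64, shape=shape)
mm[:] = np.random.rand(*shape)   # slice assignment: writes into the mapped file
mm.flush()                       # push pending writes to disk

rebound = np.random.rand(*shape)  # plain ndarray held entirely in RAM

# Re-opening with mmap_mode: data is read from disk on demand as it is indexed
reopened = np.load(path, mmap_mode="r")

print(type(mm).__name__, type(rebound).__name__, type(reopened).__name__)
# memmap ndarray memmap
```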

I tried two ways of generating the train/validation/test datasets (if anyone could also comment on the pros and cons of each, that would be very welcome):

1.

def gen(ds_mmap, start, stop):
  for i in range(start, stop):
    yield (ds_mmap[i], ds_mmap[i + 1])

tvt = {"train": None, "val": None, "test": None}
tvt_limits = {
  "train": (0, len_train),
  "val": (len_train, len_train + len_val),
  "test": (len_train + len_val, len_ds -1)  # -1 because the last case does not have a label
}

for ds_type, (start, stop) in tvt_limits.items():
  # store the dataset in tvt; assigning to a loop variable would leave tvt[ds_type] as None
  tvt[ds_type] = tf.data.Dataset.from_generator(
    generator=gen,
    output_types=(tf.float64, tf.float64),
    output_shapes=(ds_shape[1:], ds_shape[1:]),
    args=[ds_mmap, start, stop]
  )

train_ds = (
  tvt["train"]
  .shuffle(len_ds, reshuffle_each_iteration=False)
  .batch(batch_size)
)
val_ds = tvt["val"].batch(batch_size)
test_ds = tvt["test"].batch(batch_size)
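On the .shuffle(len_ds, ...) call: the tf.data documentation describes shuffle() as filling a buffer with buffer_size elements and then sampling from it. A buffer at least as large as the training split therefore ends up holding every training sample at once (12000 pairs of 60000-byte float64 cases, roughly 1.44 GB). The function below is a simplified pure-Python model of that documented behavior, not TensorFlow's implementation; buffered_shuffle is a hypothetical helper.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Simplified model of a fixed-size shuffle buffer (hypothetical helper).

    Returns the shuffled items and the largest number of items held at once.
    """
    rng = random.Random(seed)
    buffer, out, peak = [], [], 0
    for item in items:
        buffer.append(item)
        peak = max(peak, len(buffer))
        if len(buffer) >= buffer_size:
            # buffer full: emit a random element to make room for the next one
            out.append(buffer.pop(rng.randrange(len(buffer))))
    while buffer:
        # input exhausted: drain the remaining elements in random order
        out.append(buffer.pop(rng.randrange(len(buffer))))
    return out, peak


len_train, len_ds = 12000, 20000
_, peak_full = buffered_shuffle(range(len_train), buffer_size=len_ds)
_, peak_small = buffered_shuffle(range(len_train), buffer_size=1000)
print(peak_full, peak_small)
# 12000 1000
```

A smaller buffer bounds memory at the cost of a less thorough shuffle.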

2.

def gen(ds_mmap):
  for i in range(len(ds_mmap) - 1):
    yield (ds_mmap[i], ds_mmap[i + 1])

ds = tf.data.Dataset.from_generator(
  generator=gen,
  output_types=(tf.float64, tf.float64),
  output_shapes=(ds_shape[1:], ds_shape[1:]),
  args=[ds_mmap]
)

train_ds = (
  ds
  .take(len_train)
  .shuffle(len_ds, reshuffle_each_iteration=False)
  .batch(batch_size)
)
val_ds = ds.skip(len_train).take(len_val).batch(batch_size)
test_ds = ds.skip(len_train + len_val).take(len_test - 1).batch(batch_size)
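One difference between the two approaches, modeled in plain Python: a Python generator cannot jump to an arbitrary index, so in approach 2 skip() presumably has to pull (and, with a memmap, read) every skipped element before the first test sample, while approach 1 starts each generator at its own start. In the sketch, itertools.islice stands in for skip(), and the read counter is hypothetical instrumentation; both variants produce the same test pairs.

```python
from itertools import islice


def gen_all(n_cases, counter):
    """Like approach 2's gen(): one stream over all cases."""
    for i in range(n_cases - 1):
        counter["reads"] += 2  # stands for reading ds_mmap[i] and ds_mmap[i + 1]
        yield i, i + 1


def gen_range(start, stop, counter):
    """Like approach 1's gen(): starts directly at `start`."""
    for i in range(start, stop):
        counter["reads"] += 2
        yield i, i + 1


len_ds, len_train, len_val = 20000, 12000, 4000
test_start = len_train + len_val

# Approach 2: islice plays the role of ds.skip(test_start)
c2 = {"reads": 0}
test_2 = list(islice(gen_all(len_ds, c2), test_start, None))

# Approach 1: a dedicated generator for the test range
c1 = {"reads": 0}
test_1 = list(gen_range(test_start, len_ds - 1, c1))

print(len(test_1), test_1 == test_2, c1["reads"], c2["reads"])
# 3999 True 7998 39998
```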

Both approaches work, but both load the whole dataset into memory.

model = keras.Sequential([
  layers.Conv2D(64, (3, 3), input_shape=ds_shape[1:],
                activation="relu", data_format="channels_first"),
  layers.MaxPooling2D(data_format="channels_first"),
  layers.Conv2D(128, (3, 3),
                activation="relu", data_format="channels_first"),
  layers.MaxPooling2D(data_format="channels_first"),
  layers.Flatten(),
  layers.Dense(8182, activation="relu"),
  layers.Dense(np.prod(ds_shape[1:])),
  layers.Reshape(ds_shape[1:])
])

model.compile(loss="mean_absolute_error",
              optimizer="adam",
              metrics=[tf.keras.metrics.MeanSquaredError()])

hist = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs,
  # steps_per_epoch=len_train // batch_size,
  # validation_steps=len_val // batch_size,
  shuffle=True  # ignored by fit() when x is a tf.data.Dataset
)
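For reference, the layer output sizes and parameter counts of this model can be worked out by hand, assuming Keras defaults (padding='valid' and stride 1 for Conv2D, pool_size=2 for MaxPooling2D); model.summary() can be used to check them. This concerns only the model's weights, not the dataset.

```python
def conv_valid(size, kernel=3):
    """Spatial size after a stride-1 Conv2D with padding='valid' (the default)."""
    return size - kernel + 1


def pool_valid(size, pool=2):
    """Spatial size after MaxPooling2D with the default pool_size=2 and strides=2."""
    return size // pool


channels, side = 3, 50
side = pool_valid(conv_valid(side))  # Conv2D(64) + pool: 50 -> 48 -> 24
side = pool_valid(conv_valid(side))  # Conv2D(128) + pool: 24 -> 22 -> 11
flat = 128 * side * side             # Flatten: 128 * 11 * 11
out_units = channels * 50 * 50       # Dense(np.prod(ds_shape[1:])) before Reshape

params = {
    "conv1": 3 * 3 * channels * 64 + 64,  # kernel weights + biases
    "conv2": 3 * 3 * 64 * 128 + 128,
    "dense1": flat * 8182 + 8182,
    "dense2": 8182 * out_units + out_units,
}
total = sum(params.values())
print(side, flat, params["dense1"], total)
# 11 15488 126730998 188179146
```

In float32 that is about 0.75 GB of weights, most of it in the first Dense layer.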

An alternative is to subclass keras.utils.Sequence, where each call generates a whole batch.

Quoting the docs:

Sequence are a safer way to do multiprocessing. This structure guarantees that the network will only train once on each sample per epoch which is not the case with generators.

To do so, you need to implement the __len__() and __getitem__() methods.

For the current example:

class DS(keras.utils.Sequence):
  
  def __init__(self, ds_mmap, start, stop, batch_size):
    self.ds = ds_mmap[start: stop]
    self.batch_size = batch_size

  def __len__(self):
    # ceil((n - 1) / batch_size): the last case in the slice has no label
    return -(-(len(self.ds) - 1) // self.batch_size)

  def __getitem__(self, idx):
    start = idx * self.batch_size
    stop = (idx + 1) * self.batch_size
    batch_y = self.ds[start + 1: stop + 1]
    batch_x = self.ds[start: stop][: len(batch_y)]
    return batch_x, batch_y

for ds_type, (start, stop) in tvt_limits.items():
  tvt[ds_type] = DS(ds_mmap, start, stop, batch_size)
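The batching logic of DS can be checked without Keras. PairBatches below is a hypothetical stand-in with the same __len__/__getitem__ idea (it does not subclass keras.utils.Sequence). It counts batches from the number of (input, label) pairs, n - 1, because the last case in a slice has no label; counting from n instead leaves an empty final batch whenever n - 1 is a multiple of batch_size.

```python
import numpy as np


class PairBatches:
    """Batches of (case i, case i + 1) pairs from an array-like (e.g. a memmap slice)."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __len__(self):
        n_pairs = len(self.data) - 1           # the last case has no label
        return -(-n_pairs // self.batch_size)  # ceil division

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = min(start + self.batch_size, len(self.data) - 1)
        # inputs are cases start..stop-1, labels are the cases right after them
        return self.data[start:stop], self.data[start + 1:stop + 1]


for n in (33, 34):
    seq = PairBatches(np.arange(n), batch_size=32)
    print(n, len(seq), [len(seq[i][0]) for i in range(len(seq))])
# 33 1 [32]
# 34 2 [32, 1]
```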

In that case, do NOT pass a batch_size to fit(), since the Sequence already yields whole batches; here the number of steps is also defined explicitly:

hist = model.fit(
  tvt["train"],
  validation_data=tvt["val"],
  epochs=epochs,
  steps_per_epoch=len_train // batch_size,
  validation_steps=len_val // batch_size,
  shuffle=True
)

Still, I haven't managed to get from_generator() to work without loading the whole dataset into memory, and I would like to know how.
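One direction that may be worth trying, sketched under assumptions: the from_generator documentation describes args as tensors that are evaluated and passed to the generator as NumPy arrays, so passing ds_mmap through args presumably converts the whole array into a tensor up front. The sketch avoids that by passing only a file path plus start/stop, and opening the .npy file with np.load(..., mmap_mode="r") inside the generator. gen_from_path and the tiny temporary file are hypothetical, and only the generator itself is exercised here, not TensorFlow.

```python
import os
import tempfile

import numpy as np
from numpy.lib.format import open_memmap


def gen_from_path(path, start, stop):
    """Yield (case i, case i + 1) pairs for i in [start, stop), reading lazily.

    Only a path and two integers are passed in, so no large array has to be
    converted up front. String args from tf.data arrive as bytes, hence the decode.
    """
    if isinstance(path, bytes):
        path = path.decode()
    mm = np.load(path, mmap_mode="r")  # maps the file; data is read on indexing
    for i in range(int(start), int(stop)):
        # np.array copies just these two cases out of the mapped file
        yield np.array(mm[i]), np.array(mm[i + 1])


# Tiny stand-in for ds.npy (hypothetical path and shape)
path = os.path.join(tempfile.mkdtemp(), "ds_small.npy")
mm = open_memmap(path, mode="w+", dtype=np.float64, shape=(6, 3, 2, 2))
mm[:] = np.arange(mm.size, dtype=np.float64).reshape(mm.shape)
mm.flush()

pairs = list(gen_from_path(path.encode(), 0, 5))
print(len(pairs), pairs[0][1][0, 0, 0])
# 5 12.0
```

With TensorFlow it would presumably be wired up as tf.data.Dataset.from_generator(gen_from_path, output_types=(tf.float64, tf.float64), output_shapes=(ds_shape[1:], ds_shape[1:]), args=["ds.npy", start, stop]), together with a shuffle buffer much smaller than the split.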
