
Using a generator in TensorFlow/Keras to fit a model that takes 2 inputs

I want to train a model that uses an extra output layer to compute the loss (ArcFace), so the model takes two inputs: the features and the true labels, [X, y].

So far I trained with all the data loaded at once, using the following code:

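print("Unzipping DataSet to NumPy arrays")
x_train, y_train = dataset2arrays(train_ds, labels)
x_val, y_val = dataset2arrays(val_ds, val_labels)

# The model has two inputs (features and labels), so x is the list [features, labels]
model.fit(x=[x_train, y_train],
          y=y_train,
          batch_size=10,
          validation_data=([x_val, y_val], y_val),  # (inputs, targets) tuple
          epochs=20,  # Keras' fit() takes `epochs`, not `n_epochs`
         )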

However, this was done with "debugging" data, which is small (< 100 samples). The real training data is very large (> 300 GB of files), so I can't load it all at once and need to use a generator instead. In TensorFlow 2.8, one way to implement a generator is to subclass the Keras Sequence class. The following generator is based on the example in https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
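To make the Sequence contract concrete: during fit(), Keras asks len() for the number of batches per epoch, calls __getitem__(i) once per batch index, and calls on_epoch_end() when the epoch finishes. The sketch below is a Keras-free imitation of that loop on toy indexes; ToySequence and run_epoch are hypothetical names, and the real Keras loop additionally handles workers, batch-order shuffling and prefetching.

```python
import numpy as np


class ToySequence:
    """Hypothetical stand-in for a keras.utils.Sequence subclass (no Keras needed)."""

    def __init__(self, n_samples, batch_size, shuffle=True, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches per epoch; a trailing partial batch is dropped
        return self.n_samples // self.batch_size

    def __getitem__(self, index):
        # Sample indexes that belong to batch number `index`
        return self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

    def on_epoch_end(self):
        # Re-create (and optionally reshuffle) the sample order for the next epoch
        self.indexes = np.arange(self.n_samples)
        if self.shuffle:
            self.rng.shuffle(self.indexes)


def run_epoch(seq):
    """Roughly what fit() does with a Sequence during one epoch."""
    batches = [seq[i] for i in range(len(seq))]
    seq.on_epoch_end()
    return batches


seq = ToySequence(n_samples=25, batch_size=10)
epoch = run_epoch(seq)
print(len(epoch))  # 2 batches: 5 of the 25 samples are skipped this epoch
```

Because __len__ rounds down (like the np.floor in the generator below), up to batch_size - 1 samples are skipped each epoch; with shuffling, a different subset is skipped every time. Rounding up instead would keep the short last batch, since the slice in __getitem__ simply returns fewer indexes.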

from os import path

import numpy as np
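from tensorflow.keras.utils import Sequence
# Import from tensorflow.keras consistently instead of mixing in the standalone keras package
from tensorflow.keras.preprocessing.sequence import pad_sequences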

from pre_processing import load_data


class DataGenerator(Sequence):
    """Generates data for Keras
    Sequence based data generator. Suitable for building data generator for training and prediction.
    """

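    def __init__(self, list_IDs, labels, n_classes, input_path, target_path,
                 to_fit=True, batch_size=20, shuffle=True):
        """Initialization
        :param list_IDs: list of all sample ids (file names under input_path) to use in the generator
        :param labels: list of labels, aligned index-by-index with list_IDs
        :param n_classes: number of classes
        :param input_path: directory containing the input files
        :param target_path: directory containing the target files (unused in this version)
        :param to_fit: True to return X and y, False to return X only
        :param batch_size: batch size at each iteration
        :param shuffle: True to shuffle sample indexes after every epoch
        """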
        self.input_path = input_path
        self.target_path = target_path
        self.list_IDs = list_IDs
        self.labels = labels
        self.n_classes = n_classes
        self.to_fit = to_fit
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch
        :return: number of full batches per epoch (a trailing partial batch is dropped)
        """
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data
        :param index: index of the batch
        :return: X and y when fitting. X only when predicting
        """

        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        list_labels_temp = [self.labels[k] for k in indexes]

        # Generate data
        X = self._generate_X(list_IDs_temp)

        if self.to_fit:
            y = self._generate_y(list_labels_temp)
            # print(indexes) # for debugging
            return [X], y
            
        else:
            return [X]

    def on_epoch_end(self):
        """
        Updates indexes after each epoch
        """
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _generate_X(self, list_IDs_temp):
        """Generates a batch of batch_size input samples
        :param list_IDs_temp: list of sample ids to load
        :return: batch of samples, zero-padded at the end to a common length
        """
        # Initialization
        X = []

        # Load each sample of the batch from disk
        for ID in list_IDs_temp:
            # temp = self._load_input(self.input_path, ID)
            temp = load_data(path.join(self.input_path, ID))
            X.append(temp)

        # Pad to the longest sample in the batch
        # (pad_sequences casts to dtype='int32' by default)
        X = pad_sequences(X, value=0, padding='post')

        return X

    def _generate_y(self, list_labels_temp):
        """Generates the labels for one batch
        :param list_labels_temp: list of labels of the samples in the batch
        :return: batch of labels
        """
        # TODO: modify
        y = []

        # Collect the label of each sample in the batch
        for label in list_labels_temp:
            # y.append(self._load_target(self.target_path, ID))
            y.append(label)

        # y = pad_sequences(y, value=0, padding='post')

        return y
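_generate_X relies on pad_sequences(..., padding='post', value=0) to bring variable-length samples to a common length. The NumPy sketch below reproduces that behaviour for samples of shape (T_i, F), so the resulting batch shape (batch_size, max T_i, F) is easy to see; pad_post is a hypothetical helper, not a Keras function, and it ignores maxlen/truncation. One detail worth checking in the real code: pad_sequences defaults to dtype='int32', which matches the dtype=int32 tensor in the error message below. If load_data returns float features, they would be truncated to integers unless dtype='float32' is passed.

```python
import numpy as np


def pad_post(samples, value=0, dtype="float32"):
    """Hypothetical helper: pad a list of (T_i, F) arrays at the end.

    Mirrors pad_sequences(samples, value=value, padding='post', dtype=dtype)
    for the case where no maxlen/truncation is used.
    """
    arrays = [np.asarray(s, dtype=dtype) for s in samples]
    max_len = max(a.shape[0] for a in arrays)
    feature_shape = arrays[0].shape[1:]
    # Batch tensor pre-filled with the padding value
    batch = np.full((len(arrays), max_len) + feature_shape, value, dtype=dtype)
    for i, a in enumerate(arrays):
        batch[i, :a.shape[0]] = a  # real time steps first, padding stays at the end
    return batch


samples = [np.ones((3, 2)) * 0.5, np.ones((5, 2))]
print(pad_post(samples).shape)  # (2, 5, 2)
# With dtype="int32" (pad_sequences' default), the 0.5 values become 0
print(pad_post(samples, dtype="int32")[0, 0, 0])  # 0
```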

The most important part is:

if self.to_fit:
    y = self._generate_y(list_labels_temp)
    print(indexes)

    # Each option below was tried on its own (one return statement at a time)

    # Option 1
    return [X], y

    # Option 2
    return tuple([[X], [y]])

    # Option 3
    return tuple(((X), (y)))

    # Option 4
    Xy = []
    for i in range(len(y)):
        Xy.append([X[i,:,:], y[i]])
    return Xy

    # Option 5
    Xy = []
    for i in range(len(y)):
        Xy.append(X[i,:,:])
    return tuple((Xy, y))

else:
    return [X]

I tried all (or most) of these options as the generator's return value.

The new fit is:

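history = model.fit(gen,
          callbacks=callbacks,
          # No batch_size here: the Keras docs say not to pass it with a Sequence,
          # because the generator already produces whole batches
          epochs=20,
          # validation_data=tuple(validation_data),
          shuffle=True,
          verbose=1,  # display training progress in the terminal
          )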

With option 1 I get the following error:

 ValueError: Layer "ForTraining" expects 2 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, None, None) dtype=int32>]

The other options don't work either (most raise the same error as above).

So what am I doing wrong? How do I make my generator correctly return the tensors needed for training (features X and their labels y, in batches of size b)?
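A likely reason for the error: Keras splits each batch returned by the generator into (inputs, targets[, sample_weights]) and maps the inputs element-by-element onto the model's Input layers. With Option 1 the inputs are [X], a list holding a single array, so the model receives one input instead of two. The sketch below is a simplified, Keras-free imitation of that splitting (TF 2.x exposes the real logic as tf.keras.utils.unpack_x_y_sample_weight); unpack_batch and count_inputs are hypothetical helpers used only for illustration.

```python
import numpy as np


def unpack_batch(data):
    """Simplified imitation of how Keras splits a batch into (x, y, sample_weight)."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data
    raise ValueError(f"Expected 1 to 3 elements, got {len(data)}")


def count_inputs(data):
    """Hypothetical helper: number of model inputs supplied by one batch."""
    x, _, _ = unpack_batch(data)
    # A list/tuple of arrays feeds one model input per element; a bare array feeds one
    return len(x) if isinstance(x, (list, tuple)) else 1


X = np.zeros((10, 7, 3))           # padded features: (batch, time, features)
y = np.arange(10).reshape(10, 1)   # labels: (batch, 1)

print(count_inputs(([X], y)))      # Option 1 -> 1
print(count_inputs(((X, y), y)))   # inputs as a 2-tuple -> 2
```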

The following link may be relevant: https://github.com/pierluigiferrari/ssd_keras/issues/380

Note that I am running TensorFlow 2.8 with Python 3.9.5 on a Windows 10 laptop without a GPU (the real training on the full dataset will take place on a much more powerful machine; this laptop is used only for debugging).

Solution:

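The following change solves the problem, and training now runs (with validation monitoring and callbacks commented out). The key difference is that the two model inputs are returned together as the tuple (X, y), followed by the target y: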

def __getitem__(self, index):
    """Generate one batch of data
    :param index: index of the batch
    :return: X and y when fitting. X only when predicting
    """

    # Generate indexes of the batch
    indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

    # Find list of IDs
    list_IDs_temp = [self.list_IDs[k] for k in indexes]
    list_labels_temp = [self.labels[k] for k in indexes]

    # Generate data
    X = self._generate_X(list_IDs_temp)

    if self.to_fit:
        # Training/Fit case: the model has two inputs (features, labels),
        # so the inputs are returned as the tuple (X, y) and the target as y
        y = self._generate_y(list_labels_temp)
        y = np.array(y).reshape((len(y), 1))  # shape (batch_size, 1)
        return (X, y), y

    else:
        # Prediction only
        return [X]
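The sketch below reproduces the fixed __getitem__ on in-memory toy data instead of files, so the returned structure and shapes can be checked without Keras or the real dataset. ToyPairGenerator and the random data are hypothetical; in the real DataGenerator the features come from load_data and pad_sequences. The (batch_size, 1) label column presumably matches a label Input of shape (1,) in the model, but that is an assumption, since the model definition is not shown.

```python
import numpy as np


class ToyPairGenerator:
    """Hypothetical in-memory stand-in for DataGenerator with the fixed __getitem__."""

    def __init__(self, features, labels, batch_size=4, to_fit=True):
        self.features = features  # shape (n_samples, time, n_features), already padded
        self.labels = labels      # shape (n_samples,)
        self.batch_size = batch_size
        self.to_fit = to_fit
        self.indexes = np.arange(len(labels))

    def __len__(self):
        return len(self.labels) // self.batch_size

    def __getitem__(self, index):
        batch_idx = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        X = self.features[batch_idx]
        if self.to_fit:
            # Labels as a (batch_size, 1) column, used both as 2nd input and as target
            y = np.array(self.labels[batch_idx]).reshape((len(batch_idx), 1))
            return (X, y), y
        return [X]


rng = np.random.default_rng(0)
gen = ToyPairGenerator(rng.normal(size=(12, 5, 3)), rng.integers(0, 4, size=12))
(X, y_in), y_out = gen[0]
print(X.shape, y_in.shape, y_out.shape)  # (4, 5, 3) (4, 1) (4, 1)
```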

How do I use a generator for validation data? I created a second generator (identical to the training generator) and passed it as validation_data, and training completed successfully without raising an exception. So this appears to be the solution to the problem.
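The post does not show how the validation generator's sample IDs were chosen. As a hypothetical sketch, one option is to split the sample IDs and their labels once, with a fixed seed, into disjoint training and validation lists and then build one generator per split. split_ids is an illustrative helper, not part of Keras, and the file names are made up.

```python
import numpy as np


def split_ids(list_IDs, labels, val_fraction=0.2, seed=42):
    """Hypothetical helper: split IDs and their labels into disjoint train/val lists."""
    order = np.random.default_rng(seed).permutation(len(list_IDs))
    n_val = int(round(len(list_IDs) * val_fraction))
    val_idx, train_idx = order[:n_val], order[n_val:]
    train = ([list_IDs[i] for i in train_idx], [labels[i] for i in train_idx])
    val = ([list_IDs[i] for i in val_idx], [labels[i] for i in val_idx])
    return train, val


ids = [f"sample_{i}.npy" for i in range(10)]  # hypothetical file names
labels = [i % 3 for i in range(10)]
(train_ids, train_labels), (val_ids, val_labels) = split_ids(ids, labels)
print(len(train_ids), len(val_ids))  # 8 2
```

Each split could then go into its own DataGenerator, e.g. DataGenerator(train_ids, train_labels, n_classes, input_path, target_path) for training and the same call with val_ids and val_labels (shuffling is not needed for validation) passed to model.fit as validation_data.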

