简体   繁体   中英

Gradient accumulation in tensorflow 2.x / keras

I'm trying to implement Gradient accumulation on TF2.x. All implementations I've found are either for TF1.x or for the old keras interface. I don't think there is an implementation out there (though I'd be very happy to be proven wrong on this).

Here's what I'm working with:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt


class SimpleTrainStepModel(Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None


        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y, y_pred)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {m.name: m.result() for m in self.metrics}


class GradAccumModel(Model):
    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        self.train_function = None
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super(GradAccumModel, self).fit(*args,
                                               batch_size = self.batch_size,
                                               #validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,
                                               **kwargs)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum

        # def _slice_nested(obj, i, j):
        #     if type(obj) is list:
        #         return [o[i:j] for o in obj]
        #     else:
        #         return obj[i:j]

        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x[:step], training = True)  # Forward pass
            loss = self.compiled_loss(y[:step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y[:step], y_pred)

        i = tf.constant(step)
        # tf.print('TF - HERE!')
        def cond(i, *args):
            return i < self.batch_size
        def body(i, grad):
            # tf.print('\tTF - HERE!')
            with tf.GradientTape() as tape:
                y_pred = self(x[i:i + step], training = True) # Forward pass
                loss = self.compiled_loss(y[i:i + step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
            _grad = tape.gradient(loss, self.trainable_variables)

            for g,_g in zip(grad, _grad):
                g += _g

            self.compiled_metrics.update_state(y[i:i + step], y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)


        # for g in gradients:        # I tried with and without division co calculate the mean
        #     g *= 1/self.grad_accum #


        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


if __name__ == '__main__':
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    for MODEL, ga_kwarg, colour in list(zip([Model, SimpleTrainStepModel, GradAccumModel, GradAccumModel],
                                            [{}, {}, {'grad_accum': 1}, {'grad_accum': 6}],
                                            ['blue', 'green', 'yellow', 'red'])):

        for _ in tqdm(range(10)):
            # tf.random.set_seed(0)
            x = Input((28, 28))
            y = x
            y = Flatten()(y)
            y = Dense(128, activation = 'sigmoid')(y)
            y = Dense(10, activation = 'softmax')(y)

            model = MODEL(x, y)
            model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                          optimizer = tf.keras.optimizers.Adam(1e-4),
                          metrics = ['acc'])

            hist = model.fit(x_train, y_train, validation_data = (x_valid, y_valid), verbose = 0, batch_size = 6000, epochs = 100, **ga_kwarg)
            plt.plot(hist.history['val_acc'], color = colour, alpha = .25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()

I've been able to verify that it does actually save gpu memory. However, the end result is not the same as the normal Model.fit .

验证

特写

As you can see, the first three Model.fit s are well clustered and give the same results. But when the the while cycle comes into play the training is quite different.

Anyone have any idea why this is happening?

After a lot more attempts I found the solution, It seems that the main problem was the compound assignments of the gradients which don't work quite as I was expecting. Here is my final solution for anyone who might be interested. It includes the extra stuff for distributed, mixed precision trainings, and nested input/output.

from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model


class Model(_Model):
    def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
        """
        Shallow wrapper of Model.fit that captures batch_size and additional kwarg: grad_accum.

        Parameters
        ----------
        batch_size : int
            same as in Model.fit
        grad_accum_steps : int
            Number of steps to split batch_size into. The `batch_size` should be divisible by `grad_accum` (defaults to 1).
        """
        if grad_accum_steps == 1:
            super().fit(*args, batch_size = batch_size, **kwargs)

        self.train_function = None
        num_workers = ds_context.get_strategy().num_replicas_in_sync
        if batch_size % (grad_accum_steps * num_workers) != 0:
            raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')

        self._grad_accum_ = grad_accum_steps
        self._batch_size_ = batch_size
        self._num_workers_ = num_workers
        train_step_backup = self.train_step
        self.train_step = self._train_step_
        out = super(self).fit(*args,
                              batch_size = self._batch_size_, # TODO maybe consider validation batch size
                              **kwargs)

        del self._grad_accum_
        del self._batch_size_
        del self._num_workers_
        self.train_step = train_step_backup
        return out

    def _train_step_(self, data):
        """
        Custom training step taking into account gradient accumulation for low memory training
        """

        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None


        def slice_map(struct, start, stop): # dealing with nasty nested structures
            if struct is None:
                return None # special case for sample_weight

            return nest.map_structure(lambda x: x[start:stop], struct)



        # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
        step = self._batch_size_ // self._num_workers_ // self._grad_accum_
        x_ = slice_map(x, 0, step)
        y_ = slice_map(y, 0, step)
        w_ = slice_map(sample_weight, 0, step)

        with tf.GradientTape() as tape:

            y_pred = self(x_, training = True)  # Forward pass
            loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
            if isinstance(self.optimizer, lso.LossScaleOptimizer):
                loss = self.optimizer.get_scaled_loss(loss)

        gradients = tape.gradient(loss, self.trainable_variables)
        gradients = [gradient * (1./self._grad_accum_) for gradient in gradients]
        self.compiled_metrics.update_state(y_, y_pred)

        i = tf.constant(step)
        def cond(i, *args):
            return i < self._batch_size_

        def body(i, grad):
            x_ = slice_map(x, i, i + step)
            y_ = slice_map(y, i, i + step)
            w_ = slice_map(sample_weight, i, i + step)

            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True) # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if isinstance(self.optimizer, lso.LossScaleOptimizer):
                    loss = self.optimizer.get_scaled_loss(loss)

            _grad = tape.gradient(loss, self.trainable_variables)
            _grad = [_g * (1./self._grad_accum_) for _g in _grad]

            grad = [g + _g for g,_g in zip(grad, _grad)]

            self.compiled_metrics.update_state(y_, y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)
        # --------------------------------------------------------------------------------------------------------------



        # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
        aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended))

        if aggregate_grads_outside_optimizer: # TODO there might be some issues with the scaling, due to the extra accumulation steps
            gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))

        if isinstance(self.optimizer, lso.LossScaleOptimizer):
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients = self.optimizer._clip_gradients(gradients)
        if self.trainable_variables:
            if aggregate_grads_outside_optimizer:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
            else:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # --------------------------------------------------------------------------------------------------------------


        return {m.name: m.result() for m in self.metrics}

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM