
Gradient accumulation in tensorflow 2.x / keras

I'm trying to implement gradient accumulation on TF 2.x. All the implementations I've found are either for TF 1.x or for the old standalone Keras interface. I don't think there is an existing implementation for TF 2.x (though I'd be very happy to be proven wrong on this).
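
For context, gradient accumulation means running several small forward/backward passes and summing (or averaging) the gradients before taking a single optimizer step, so that a large effective batch fits in GPU memory. A minimal sketch with a plain custom training loop (the helper name and the micro_batches argument are hypothetical, just to illustrate the idea):

import tensorflow as tf

def accumulate_and_apply(model, optimizer, loss_fn, micro_batches):
    # Hypothetical helper: sum the averaged gradients over all
    # micro-batches, then apply them in a single optimizer step.
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    n = len(micro_batches)
    for x, y in micro_batches:
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training = True))
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g / n for a, g in zip(accum, grads)]
    optimizer.apply_gradients(zip(accum, model.trainable_variables))

What I want, though, is the same behaviour from inside Keras' own Model.fit / train_step machinery.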

Here's what I'm working with:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt


class SimpleTrainStepModel(Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None


        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y, y_pred)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {m.name: m.result() for m in self.metrics}


class GradAccumModel(Model):
    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        self.train_function = None
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super(GradAccumModel, self).fit(*args,
                                               batch_size = self.batch_size,
                                               #validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,
                                               **kwargs)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum

        # def _slice_nested(obj, i, j):
        #     if type(obj) is list:
        #         return [o[i:j] for o in obj]
        #     else:
        #         return obj[i:j]

        # FIRST GRADIENT
        with tf.GradientTape() as tape:
            y_pred = self(x[:step], training = True)  # Forward pass
            loss = self.compiled_loss(y[:step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.compiled_metrics.update_state(y[:step], y_pred)

        i = tf.constant(step)
        # tf.print('TF - HERE!')
        def cond(i, *args):
            return i < self.batch_size
        def body(i, grad):
            # tf.print('\tTF - HERE!')
            with tf.GradientTape() as tape:
                y_pred = self(x[i:i + step], training = True) # Forward pass
                loss = self.compiled_loss(y[i:i + step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
            _grad = tape.gradient(loss, self.trainable_variables)

            # accumulate this micro-batch's gradient in place
            for g, _g in zip(grad, _grad):
                g += _g

            self.compiled_metrics.update_state(y[i:i + step], y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)


        # for g in gradients:        # I tried with and without division to calculate the mean
        #     g *= 1/self.grad_accum #


        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


if __name__ == '__main__':
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    for MODEL, ga_kwarg, colour in list(zip([Model, SimpleTrainStepModel, GradAccumModel, GradAccumModel],
                                            [{}, {}, {'grad_accum': 1}, {'grad_accum': 6}],
                                            ['blue', 'green', 'yellow', 'red'])):

        for _ in tqdm(range(10)):
            # tf.random.set_seed(0)
            x = Input((28, 28))
            y = x
            y = Flatten()(y)
            y = Dense(128, activation = 'sigmoid')(y)
            y = Dense(10, activation = 'softmax')(y)

            model = MODEL(x, y)
            model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                          optimizer = tf.keras.optimizers.Adam(1e-4),
                          metrics = ['acc'])

            hist = model.fit(x_train, y_train, validation_data = (x_valid, y_valid), verbose = 0, batch_size = 6000, epochs = 100, **ga_kwarg)
            plt.plot(hist.history['val_acc'], color = colour, alpha = .25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()

I've been able to verify that it does actually save GPU memory. However, the end result is not the same as with the plain Model.fit.

[Figure: validation accuracy curves for the four runs]

[Figure: close-up of the same curves]

As you can see, the first three Model.fit variants are well clustered and give the same results. But once the while loop comes into play, the training is quite different.

Does anyone have any idea why this is happening?

After a lot more attempts I found the solution. It seems the main problem was the compound assignment of the gradients, which doesn't work quite as I was expecting. Here is my final solution for anyone who might be interested. It includes the extra machinery for distributed training, mixed-precision training, and nested inputs/outputs.
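
To see why the compound assignment fails: tf.Tensor objects are immutable, so g += _g creates a new tensor and rebinds the loop-local name g, while the grad list that tf.while_loop carries over to the next iteration is left untouched. A standalone reproduction of the pitfall (not part of the model code):

import tensorflow as tf

grad = [tf.constant(1.0)]
for g in grad:
    g += 1.0          # rebinds the local name `g` only
print(grad)           # still [1.0] -- the accumulated value is lost

grad = [g + 1.0 for g in grad]  # builds a new list: [2.0]

This is why the fixed version below builds a new gradient list on every loop iteration instead of updating it in place.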

import tensorflow as tf

from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model


class Model(_Model):
    def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
        """
        Shallow wrapper of Model.fit that captures batch_size and the additional kwarg grad_accum_steps.

        Parameters
        ----------
        batch_size : int
            same as in Model.fit
        grad_accum_steps : int
            Number of micro-steps to split batch_size into (defaults to 1). `batch_size` must be divisible by `grad_accum_steps` times the number of replicas.
        """
        if grad_accum_steps == 1:
            return super().fit(*args, batch_size = batch_size, **kwargs)

        self.train_function = None
        num_workers = ds_context.get_strategy().num_replicas_in_sync
        if batch_size % (grad_accum_steps * num_workers) != 0:
            raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')

        self._grad_accum_ = grad_accum_steps
        self._batch_size_ = batch_size
        self._num_workers_ = num_workers
        train_step_backup = self.train_step
        self.train_step = self._train_step_
        out = super().fit(*args,
                          batch_size = self._batch_size_, # TODO maybe consider validation batch size
                          **kwargs)

        del self._grad_accum_
        del self._batch_size_
        del self._num_workers_
        self.train_step = train_step_backup
        return out

    def _train_step_(self, data):
        """
        Custom training step taking into account gradient accumulation for low memory training
        """

        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None


        def slice_map(struct, start, stop): # dealing with nasty nested structures
            if struct is None:
                return None # special case for sample_weight

            return nest.map_structure(lambda x: x[start:stop], struct)



        # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
        step = self._batch_size_ // self._num_workers_ // self._grad_accum_
        x_ = slice_map(x, 0, step)
        y_ = slice_map(y, 0, step)
        w_ = slice_map(sample_weight, 0, step)

        with tf.GradientTape() as tape:

            y_pred = self(x_, training = True)  # Forward pass
            loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
            if isinstance(self.optimizer, lso.LossScaleOptimizer):
                loss = self.optimizer.get_scaled_loss(loss)

        gradients = tape.gradient(loss, self.trainable_variables)
        gradients = [gradient * (1./self._grad_accum_) for gradient in gradients]
        self.compiled_metrics.update_state(y_, y_pred)

        i = tf.constant(step)
        def cond(i, *args):
            return i < self._batch_size_

        def body(i, grad):
            x_ = slice_map(x, i, i + step)
            y_ = slice_map(y, i, i + step)
            w_ = slice_map(sample_weight, i, i + step)

            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True) # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if isinstance(self.optimizer, lso.LossScaleOptimizer):
                    loss = self.optimizer.get_scaled_loss(loss)

            _grad = tape.gradient(loss, self.trainable_variables)
            _grad = [_g * (1./self._grad_accum_) for _g in _grad]

            grad = [g + _g for g,_g in zip(grad, _grad)]

            self.compiled_metrics.update_state(y_, y_pred)
            return [i + step, grad]

        i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)
        # --------------------------------------------------------------------------------------------------------------



        # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
        aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended))

        if aggregate_grads_outside_optimizer: # TODO there might be some issues with the scaling, due to the extra accumulation steps
            gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))

        if isinstance(self.optimizer, lso.LossScaleOptimizer):
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients = self.optimizer._clip_gradients(gradients)
        if self.trainable_variables:
            if aggregate_grads_outside_optimizer:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
            else:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # --------------------------------------------------------------------------------------------------------------


        return {m.name: m.result() for m in self.metrics}
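
For completeness, a usage sketch with the same toy MNIST setup as in the question (assuming the Input/Flatten/Dense imports and the x_train, y_train, x_valid, y_valid arrays from above; Model here is the patched subclass just defined):

x = Input((28, 28))
y = Dense(10, activation = 'softmax')(Dense(128, activation = 'sigmoid')(Flatten()(x)))

model = Model(x, y)
model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer = tf.keras.optimizers.Adam(1e-4),
              metrics = ['acc'])
model.fit(x_train, y_train,
          validation_data = (x_valid, y_valid),
          batch_size = 6000,
          grad_accum_steps = 6,  # 6 micro-batches of 1000 samples each
          epochs = 100)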
