I'm trying to implement gradient accumulation in TF 2.x. All the implementations I've found are either for TF 1.x or for the old Keras interface. I don't think there is an implementation out there for TF 2.x (though I'd be very happy to be proven wrong on this).
Here's what I'm working with:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt
class SimpleTrainStepModel(Model):
    """Keras model with a hand-written, single-pass `train_step`.

    Each batch gets one forward pass, one backward pass and one optimizer
    update, with no gradient accumulation.
    """

    def train_step(self, data):
        """Run one optimisation step on a single batch.

        Parameters
        ----------
        data : tuple
            Either `(x, y)` or `(x, y, sample_weight)`; the structure depends
            on the model and on what is passed to `fit()`.

        Returns
        -------
        dict
            Maps each metric name to its current result.
        """
        # Unpack the data: a 3-tuple carries per-sample weights, a 2-tuple does not.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)

        # Forward the sample weights to the metrics as well, so that loss and
        # metrics are computed under the same weighting.
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {m.name: m.result() for m in self.metrics}
class GradAccumModel(Model):
    """Keras model whose `train_step` accumulates gradients over batch slices.

    Each batch handed to `train_step` is processed in chunks of
    `batch_size // grad_accum` samples. The gradients of the chunks are
    combined into one gradient per variable and applied with a single
    optimizer update.

    `x`, `y` and `sample_weight` are sliced directly with `[start:stop]`, so
    this class assumes each of them is a single tensor rather than a nested
    structure.
    """

    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        """Shallow wrapper of `Model.fit` that records the accumulation setup.

        Parameters
        ----------
        batch_size : int
            Forwarded to `Model.fit`.
        grad_accum : int
            Number of chunks each batch is split into (defaults to 1).
            `batch_size` must be divisible by it.

        Returns
        -------
        Whatever `Model.fit` returns.

        Raises
        ------
        ValueError
            If `grad_accum` is smaller than 1, or `batch_size` is not
            divisible by `grad_accum`.
        """
        if grad_accum < 1:
            raise ValueError('The number of Gradient accumulation steps must be at least 1.')
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        # `train_step` reads the two attributes below as plain Python values,
        # so drop any previously built train function before they change.
        self.train_function = None
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super().fit(*args, batch_size = self.batch_size, **kwargs)

    def train_step(self, data):
        """Run one optimisation step, accumulating gradients chunk by chunk.

        Parameters
        ----------
        data : tuple
            Either `(x, y)` or `(x, y, sample_weight)`; the structure depends
            on the model and on what is passed to `fit()`.

        Returns
        -------
        dict
            Maps each metric name to its current result.
        """
        # Unpack the data: a 3-tuple carries per-sample weights, a 2-tuple does not.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum
        # Use the number of samples actually present in this batch as the loop
        # bound: a batch shorter than `batch_size` would otherwise be sliced
        # past its end and produce empty chunks.
        n_samples = tf.shape(x)[0]

        def chunk_gradients(start, stop):
            """Return the gradients of samples [start, stop), scaled by the chunk's share of the batch."""
            x_ = x[start:stop]
            y_ = y[start:stop]
            w_ = None if sample_weight is None else sample_weight[start:stop]
            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
            grads = tape.gradient(loss, self.trainable_variables)
            self.compiled_metrics.update_state(y_, y_pred, w_)
            # Weight every chunk by chunk_size / n_samples so the weights sum
            # to one over the batch. For equally sized chunks this is
            # 1 / grad_accum, i.e. the accumulated result is the mean of the
            # chunk gradients rather than their sum.
            share = tf.cast(tf.shape(x_)[0], tf.float32) / tf.cast(n_samples, tf.float32)
            return [g * tf.cast(share, g.dtype) for g in grads]

        # The first chunk is processed outside the loop and seeds the accumulator.
        # NOTE(review): presumably this also ensures the compiled loss/metric
        # state exists before `tf.while_loop` builds its body -- confirm.
        gradients = chunk_gradients(0, step)

        def cond(i, *args):
            return i < n_samples

        def body(i, grad):
            new_grad = chunk_gradients(i, i + step)
            # Build a NEW list. `g += _g` inside a `for` loop only rebinds the
            # loop variable and leaves `grad` untouched, so nothing would be
            # accumulated.
            grad = [g + _g for g, _g in zip(grad, new_grad)]
            return [i + step, grad]

        _, gradients = tf.while_loop(cond, body, [tf.constant(step), gradients], parallel_iterations = 1)

        # Update weights once per batch with the accumulated gradients.
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Return a dict mapping metric names to current value
        # (includes the metric that tracks the loss).
        return {m.name: m.result() for m in self.metrics}
if __name__ == '__main__':
    # Benchmark script: train the same small MNIST classifier with each model
    # class and overlay the validation-accuracy curves, one colour per variant.
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    # One row per variant: (model class, extra fit() keyword arguments, plot colour).
    experiments = (
        (Model, {}, 'blue'),
        (SimpleTrainStepModel, {}, 'green'),
        (GradAccumModel, {'grad_accum': 1}, 'yellow'),
        (GradAccumModel, {'grad_accum': 6}, 'red'),
    )

    for model_cls, extra_fit_kwargs, curve_colour in experiments:
        # Ten repetitions per variant. No random seed is set, so every
        # repetition starts from its own random initialisation.
        for _ in tqdm(range(10)):
            inputs = Input((28, 28))
            flat = Flatten()(inputs)
            hidden = Dense(128, activation='sigmoid')(flat)
            outputs = Dense(10, activation='softmax')(hidden)

            net = model_cls(inputs, outputs)
            net.compile(
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(1e-4),
                metrics=['acc'],
            )
            history = net.fit(
                x_train,
                y_train,
                validation_data=(x_valid, y_valid),
                verbose=0,
                batch_size=6000,
                epochs=100,
                **extra_fit_kwargs,
            )
            plt.plot(history.history['val_acc'], color=curve_colour, alpha=.25)

    plt.title('')
    plt.xscale('symlog')
    plt.yscale('logit')
    plt.show()
I've been able to verify that it does actually save GPU memory. However, the end result is not the same as with the normal Model.fit.
As you can see, the first three Model.fit runs are well clustered and give the same results. But when the while loop comes into play, the training is quite different.
Does anyone have any idea why this is happening?
After a lot more attempts I found the solution. It seems that the main problem was the compound assignment of the gradients (`g += _g` inside the `for` loop), which doesn't work quite as I was expecting: it only rebinds the local name `g`, so the list of gradients returned by the loop body is never updated. Here is my final solution for anyone who might be interested. It includes the extra handling for distributed training, mixed-precision training, and nested inputs/outputs.
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model
class Model(_Model):
    """Keras `Model` with optional gradient accumulation in `fit`.

    When `fit` is called with `grad_accum_steps > 1`, the regular `train_step`
    is temporarily replaced by `_train_step_`. That step processes each batch
    in chunks, accumulates the chunk gradients and applies them with a single
    optimizer update.
    """

    def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
        """
        Shallow wrapper of Model.fit that captures batch_size and the additional kwarg grad_accum_steps.

        Parameters
        ----------
        batch_size : int
            same as in Model.fit
        grad_accum_steps : int
            Number of steps to split batch_size into. The `batch_size` should be divisible by
            `grad_accum_steps` times the number of replicas in sync (defaults to 1).

        Returns
        -------
        Whatever Model.fit returns.

        Raises
        ------
        ValueError
            If `grad_accum_steps` is smaller than 1, or `batch_size` is not divisible by
            `grad_accum_steps` times the number of replicas.
        """
        if grad_accum_steps < 1:
            raise ValueError(f'The number of Gradient accumulation steps ({grad_accum_steps}) must be at least 1.')
        if grad_accum_steps == 1:
            # Nothing to accumulate: plain Model.fit. The `return` matters --
            # without it the model would train a second time below.
            return super().fit(*args, batch_size = batch_size, **kwargs)

        num_workers = ds_context.get_strategy().num_replicas_in_sync
        if batch_size % (grad_accum_steps * num_workers) != 0:
            raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')

        # Drop any previously built train function so that the swapped-in step is picked up.
        self.train_function = None
        self._grad_accum_ = grad_accum_steps
        self._batch_size_ = batch_size
        self._num_workers_ = num_workers
        train_step_backup = self.train_step
        self.train_step = self._train_step_
        try:
            return super().fit(*args,
                               batch_size = self._batch_size_,  # TODO maybe consider validation batch size
                               **kwargs)
        finally:
            # Restore the model even if training raised, and drop the train
            # function built around `_train_step_` so a later plain `fit`
            # does not reuse it.
            del self._grad_accum_
            del self._batch_size_
            del self._num_workers_
            self.train_step = train_step_backup
            self.train_function = None

    def _train_step_(self, data):
        """
        Custom training step taking into account gradient accumulation for low memory training

        Parameters
        ----------
        data : tuple
            Either `(x, y)` or `(x, y, sample_weight)`; each element may be a nested structure of tensors.

        Returns
        -------
        dict
            Maps each metric name to its current result.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        def slice_map(struct, start, stop):  # dealing with nasty nested structures
            if struct is None:
                return None  # special case for sample_weight
            return nest.map_structure(lambda t: t[start:stop], struct)

        # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
        step = self._batch_size_ // self._num_workers_ // self._grad_accum_
        # Number of samples actually handed to this step, read from the first input tensor.
        # The loop is bounded by it instead of the global batch size, which would run past
        # the local data when the batch is split over replicas or when a batch is short.
        n_samples = tf.shape(nest.flatten(x)[0])[0]
        uses_loss_scaling = isinstance(self.optimizer, lso.LossScaleOptimizer)

        def chunk_gradients(start, stop):
            """Return the gradients of samples [start, stop), scaled by the chunk's share of the batch."""
            x_ = slice_map(x, start, stop)
            y_ = slice_map(y, start, stop)
            w_ = slice_map(sample_weight, start, stop)
            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if uses_loss_scaling:
                    loss = self.optimizer.get_scaled_loss(loss)
            grads = tape.gradient(loss, self.trainable_variables)
            self.compiled_metrics.update_state(y_, y_pred, w_)
            # chunk_size / n_samples: equals 1 / grad_accum_steps for a full batch and stays
            # correct when the last chunk is shorter than `step`.
            chunk_size = tf.minimum(stop, n_samples) - start
            share = tf.cast(chunk_size, tf.float32) / tf.cast(n_samples, tf.float32)
            return [g * tf.cast(share, g.dtype) for g in grads]

        # The first chunk is processed outside the loop and seeds the accumulator.
        # NOTE(review): presumably this also ensures the compiled loss/metric state exists
        # before `tf.while_loop` builds its body -- confirm.
        gradients = chunk_gradients(0, step)

        def cond(i, *args):
            return i < n_samples

        def body(i, grad):
            _grad = chunk_gradients(i, i + step)
            # A new list is built on purpose: in-place `g += _g` on the loop variable would
            # not update `grad`.
            grad = [g + _g for g, _g in zip(grad, _grad)]
            return [i + step, grad]

        _, gradients = tf.while_loop(cond, body, [tf.constant(step), gradients], parallel_iterations = 1)
        # --------------------------------------------------------------------------------------------------------------

        # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
        aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended))

        if aggregate_grads_outside_optimizer:  # TODO there might be some issues with the scaling, due to the extra accumulation steps
            gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))

        if uses_loss_scaling:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients = self.optimizer._clip_gradients(gradients)
        if self.trainable_variables:
            if aggregate_grads_outside_optimizer:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
            else:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # --------------------------------------------------------------------------------------------------------------

        return {m.name: m.result() for m in self.metrics}
The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, please contact: yoyou2525@163.com.