[英]Gradient accumulation in tensorflow 2.x / keras
I'm trying to implement Gradient accumulation on TF2.x.我正在尝试在 TF2.x 上实现梯度累积。 All implementations I've found are either for TF1.x or for the old keras interface.
我发现的所有实现都适用于 TF1.x 或旧的 keras 接口。 I don't think there is an implementation out there (though I'd be very happy to be proven wrong on this).
我认为那里没有实现(尽管我很高兴被证明是错误的)。
Here's what I'm working with:这是我正在使用的:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt
class SimpleTrainStepModel(Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
(x, y), sample_weight = data, None
# FIRST GRADIENT
with tf.GradientTape() as tape:
y_pred = self(x, training = True) # Forward pass
loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
gradients = tape.gradient(loss, self.trainable_variables)
self.compiled_metrics.update_state(y, y_pred)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
return {m.name: m.result() for m in self.metrics}
class GradAccumModel(Model):
def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
self.train_function = None
if batch_size % grad_accum != 0:
raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
self.grad_accum = grad_accum
self.batch_size = batch_size
return super(GradAccumModel, self).fit(*args,
batch_size = self.batch_size,
#validation_batch_size = validation_batch_size,#self.batch_size//grad_accum if validation_batch_size is None else validation_batch_size,
**kwargs)
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
(x, y), sample_weight = data, None
step = self.batch_size // self.grad_accum
# def _slice_nested(obj, i, j):
# if type(obj) is list:
# return [o[i:j] for o in obj]
# else:
# return obj[i:j]
# FIRST GRADIENT
with tf.GradientTape() as tape:
y_pred = self(x[:step], training = True) # Forward pass
loss = self.compiled_loss(y[:step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
gradients = tape.gradient(loss, self.trainable_variables)
self.compiled_metrics.update_state(y[:step], y_pred)
i = tf.constant(step)
# tf.print('TF - HERE!')
def cond(i, *args):
return i < self.batch_size
def body(i, grad):
# tf.print('\tTF - HERE!')
with tf.GradientTape() as tape:
y_pred = self(x[i:i + step], training = True) # Forward pass
loss = self.compiled_loss(y[i:i + step], y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
_grad = tape.gradient(loss, self.trainable_variables)
for g,_g in zip(grad, _grad):
g += _g
self.compiled_metrics.update_state(y[i:i + step], y_pred)
return [i + step, grad]
i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)
# for g in gradients: # I tried with and without division co calculate the mean
# g *= 1/self.grad_accum #
# Update weights
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
# Update metrics (includes the metric that tracks the loss)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
if __name__ == '__main__':
(x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()
for MODEL, ga_kwarg, colour in list(zip([Model, SimpleTrainStepModel, GradAccumModel, GradAccumModel],
[{}, {}, {'grad_accum': 1}, {'grad_accum': 6}],
['blue', 'green', 'yellow', 'red'])):
for _ in tqdm(range(10)):
# tf.random.set_seed(0)
x = Input((28, 28))
y = x
y = Flatten()(y)
y = Dense(128, activation = 'sigmoid')(y)
y = Dense(10, activation = 'softmax')(y)
model = MODEL(x, y)
model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
optimizer = tf.keras.optimizers.Adam(1e-4),
metrics = ['acc'])
hist = model.fit(x_train, y_train, validation_data = (x_valid, y_valid), verbose = 0, batch_size = 6000, epochs = 100, **ga_kwarg)
plt.plot(hist.history['val_acc'], color = colour, alpha = .25)
plt.title('')
plt.xscale('symlog')
plt.yscale('logit')
plt.show()
I've been able to verify that it does actually save gpu memory.我已经能够验证它确实保存了 gpu memory。 However, the end result is not the same as the normal
Model.fit
.但是,最终结果与正常的
Model.fit
。
As you can see, the first three Model.fit
s are well clustered and give the same results.如您所见,前三个
Model.fit
聚类良好并给出相同的结果。 But when the the while
cycle comes into play the training is quite different.但是当
while
循环开始发挥作用时,训练就完全不同了。
Anyone have any idea why this is happening?有人知道为什么会这样吗?
After a lot more attempts I found the solution, It seems that the main problem was the compound assignments of the gradients which don't work quite as I was expecting.经过多次尝试,我找到了解决方案,似乎主要问题是渐变的复合分配,它不像我预期的那样工作。 Here is my final solution for anyone who might be interested.
对于任何可能感兴趣的人,这是我的最终解决方案。 It includes the extra stuff for distributed, mixed precision trainings, and nested input/output.
它包括用于分布式、混合精度训练和嵌套输入/输出的额外内容。
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model
class Model(_Model):
def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
"""
Shallow wrapper of Model.fit that captures batch_size and additional kwarg: grad_accum.
Parameters
----------
batch_size : int
same as in Model.fit
grad_accum_steps : int
Number of steps to split batch_size into. The `batch_size` should be divisible by `grad_accum` (defaults to 1).
"""
if grad_accum_steps == 1:
super().fit(*args, batch_size = batch_size, **kwargs)
self.train_function = None
num_workers = ds_context.get_strategy().num_replicas_in_sync
if batch_size % (grad_accum_steps * num_workers) != 0:
raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')
self._grad_accum_ = grad_accum_steps
self._batch_size_ = batch_size
self._num_workers_ = num_workers
train_step_backup = self.train_step
self.train_step = self._train_step_
out = super(self).fit(*args,
batch_size = self._batch_size_, # TODO maybe consider validation batch size
**kwargs)
del self._grad_accum_
del self._batch_size_
del self._num_workers_
self.train_step = train_step_backup
return out
def _train_step_(self, data):
"""
Custom training step taking into account gradient accumulation for low memory training
"""
if len(data) == 3:
x, y, sample_weight = data
else:
(x, y), sample_weight = data, None
def slice_map(struct, start, stop): # dealing with nasty nested structures
if struct is None:
return None # special case for sample_weight
return nest.map_structure(lambda x: x[start:stop], struct)
# ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
step = self._batch_size_ // self._num_workers_ // self._grad_accum_
x_ = slice_map(x, 0, step)
y_ = slice_map(y, 0, step)
w_ = slice_map(sample_weight, 0, step)
with tf.GradientTape() as tape:
y_pred = self(x_, training = True) # Forward pass
loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
if isinstance(self.optimizer, lso.LossScaleOptimizer):
loss = self.optimizer.get_scaled_loss(loss)
gradients = tape.gradient(loss, self.trainable_variables)
gradients = [gradient * (1./self._grad_accum_) for gradient in gradients]
self.compiled_metrics.update_state(y_, y_pred)
i = tf.constant(step)
def cond(i, *args):
return i < self._batch_size_
def body(i, grad):
x_ = slice_map(x, i, i + step)
y_ = slice_map(y, i, i + step)
w_ = slice_map(sample_weight, i, i + step)
with tf.GradientTape() as tape:
y_pred = self(x_, training = True) # Forward pass
loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
if isinstance(self.optimizer, lso.LossScaleOptimizer):
loss = self.optimizer.get_scaled_loss(loss)
_grad = tape.gradient(loss, self.trainable_variables)
_grad = [_g * (1./self._grad_accum_) for _g in _grad]
grad = [g + _g for g,_g in zip(grad, _grad)]
self.compiled_metrics.update_state(y_, y_pred)
return [i + step, grad]
i, gradients = tf.while_loop(cond, body, [i, gradients], parallel_iterations = 1)
# --------------------------------------------------------------------------------------------------------------
# ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
aggregate_grads_outside_optimizer = (self.optimizer._HAS_AGGREGATE_GRAD and not isinstance(self.distribute_strategy.extended, parameter_server_strategy.ParameterServerStrategyExtended))
if aggregate_grads_outside_optimizer: # TODO there might be some issues with the scaling, due to the extra accumulation steps
gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))
if isinstance(self.optimizer, lso.LossScaleOptimizer):
gradients = self.optimizer.get_unscaled_gradients(gradients)
gradients = self.optimizer._clip_gradients(gradients)
if self.trainable_variables:
if aggregate_grads_outside_optimizer:
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
else:
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
# --------------------------------------------------------------------------------------------------------------
return {m.name: m.result() for m in self.metrics}
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.