Accumulate gradients with distributed strategy in Tensorflow 2
I have implemented a distributed strategy to train my model on multiple GPUs:
strategy = tf.distribute.MirroredStrategy(devices=devices[:FLAGS.n_gpus])
strategy.run(fn=self.train_step, args=(model, data))
My model has now become more complex and larger, and I had to reduce the batch size to fit it onto the GPUs. The gradient is now quite noisy and I want to increase the effective batch size again by accumulating gradients.
Now my question is: is this even possible when using a mirrored strategy? I know that loss and gradients are combined across the replicas anyway, so is there a way to sum them both across the replicas AND across e.g. a loop running over the batches? I tried the straightforward thing and returned the per-replica gradients so I could add them up and apply them outside of strategy.run(), like this:
for b in batches:
    per_replica_gradients = strategy.run(fn=self.train_step, args=(model, data))
    total_gradient += per_replica_gradients
optimizer.apply_gradients(zip(total_gradient, model.trainable_variables))
but Tensorflow tells me that this is not possible and that the gradients have to be applied within strategy.run(). This also makes sense to me, but I wonder whether there is a possibility to accumulate gradients AND use a mirrored strategy?
You could use tf.distribute.ReplicaContext.all_reduce: this differs from Strategy.reduce in that it is for replica context and does not copy the results to the host device. all_reduce should typically be used for reductions inside the training step, such as gradients. More details can be found in the documentation here.
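For illustration, here is a minimal sketch of how gradient accumulation could be combined with MirroredStrategy and all_reduce. It assumes TF 2.x with a Keras optimizer whose apply_gradients still accepts experimental_aggregate_gradients; the model, optimizer, GLOBAL_BATCH_SIZE and ACCUM_STEPS below are placeholders and not part of the original question:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64   # placeholder: per-replica batch * n_replicas * ACCUM_STEPS
ACCUM_STEPS = 4          # placeholder: micro-batches accumulated per optimizer update

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])   # placeholder model
    optimizer = tf.keras.optimizers.SGD(1e-3)
    # One accumulator per trainable variable; ON_READ keeps a local copy per replica.
    accum = [tf.Variable(tf.zeros_like(v), trainable=False,
                         synchronization=tf.VariableSynchronization.ON_READ,
                         aggregation=tf.VariableAggregation.SUM)
             for v in model.trainable_variables]

loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM)

@tf.function
def accumulate_step(x, y):
    def step_fn(x, y):
        # Runs in replica context: only the local accumulators are updated.
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True)) / GLOBAL_BATCH_SIZE
        grads = tape.gradient(loss, model.trainable_variables)
        for acc, g in zip(accum, grads):
            acc.assign_add(g)
        return loss
    return strategy.run(step_fn, args=(x, y))

@tf.function
def apply_step():
    def step_fn():
        # Still in replica context: sum the accumulated gradients across replicas
        # with all_reduce, then apply them without a second aggregation pass.
        ctx = tf.distribute.get_replica_context()
        reduced = ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                 [acc.read_value() for acc in accum])
        optimizer.apply_gradients(zip(reduced, model.trainable_variables),
                                  experimental_aggregate_gradients=False)
        for acc in accum:
            acc.assign(tf.zeros_like(acc))
    strategy.run(step_fn)

The training loop would then call accumulate_step for every micro-batch of a distributed dataset and apply_step once every ACCUM_STEPS micro-batches, so both the accumulation and the optimizer update happen inside strategy.run(), as Tensorflow requires.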