Understanding when to call zero_grad() in pytorch, when training with multiple losses

I am going through an open-source implementation of a domain-adversarial model (GAN-like). The implementation uses pytorch and I am not sure they use zero_grad() correctly. They call zero_grad() for the encoder optimizer (aka the generator) before updating the discriminator loss. However zero_grad() is hardly documented, and I couldn't find information about it.

Here is pseudo code comparing a standard GAN training loop (option 1) with their implementation (option 2). I think the second option is wrong, because it may accumulate the D_loss gradients into the parameters that E_opt updates. Can someone tell if these two pieces of code are equivalent?

Option 1 (a standard GAN implementation):

X, y = get_D_batch()
D_opt.zero_grad()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()

X, y = get_E_batch()
E_opt.zero_grad()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()

Option 2 (calling zero_grad() for both optimizers at the beginning):

E_opt.zero_grad()
D_opt.zero_grad()

X, y = get_D_batch()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()

X, y = get_E_batch()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()

It depends on the params argument of the torch.optim.Optimizer subclass (e.g. torch.optim.SGD) and on the exact structure of the model.

Assuming E_opt and D_opt have different sets of parameters (model.encoder and model.decoder do not share weights), something like this:

E_opt = torch.optim.Adam(model.encoder.parameters())
D_opt = torch.optim.Adam(model.decoder.parameters())

both options MIGHT indeed be equivalent (see the comments in the code below for your source; additionally, I have added backward(), which is really important here, and changed model to discriminator and generator as appropriate, since I assume that's the case):

# Starting with zero gradient
E_opt.zero_grad()
D_opt.zero_grad()

# See comment below for possible cases
X, y = get_D_batch()
pred = discriminator(X)
D_loss = loss(pred, y)
# This will accumulate gradients in discriminator only
# OR in discriminator and generator, depends on other parts of code
# See below for commentary
D_loss.backward()
# Correct weights of discriminator
D_opt.step()

# This only relies on random noise input so discriminator
# Is not part of this equation
X, y = get_E_batch()
pred = generator(X)
E_loss = loss(pred, y)
E_loss.backward()
# So only parameters of generator are updated always
E_opt.step()

Now it all comes down to how get_D_batch feeds data to the discriminator.

Case 1 - real samples

This is not a problem, as it does not involve the generator; you pass real samples and only the discriminator takes part in this operation.
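A minimal sketch of this case (real_batch and real_label are hypothetical names for whatever your data pipeline returns):

pred = discriminator(real_batch)       # the generator never enters this graph
D_loss_real = loss(pred, real_label)
D_loss_real.backward()                 # gradients land in discriminator.parameters() only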

Case 2 - generated samples

Naive case

Here gradient accumulation may indeed occur. It would occur if get_D_batch simply called X = generator(noise) and passed this data to the discriminator. In that case both discriminator and generator have their gradients accumulated during backward(), as both take part in the computation.
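A minimal, self-contained sketch of this failure mode, using tiny toy modules (the layer sizes, noise and fake_label are placeholders, not the real model):

import torch
import torch.nn as nn

# Toy stand-ins just to make the accumulation visible
generator = nn.Linear(8, 4)
discriminator = nn.Linear(4, 1)
loss = nn.BCEWithLogitsLoss()
D_opt = torch.optim.Adam(discriminator.parameters())

noise = torch.randn(16, 8)
fake_label = torch.zeros(16, 1)

fake = generator(noise)               # generator ops are recorded in the graph
pred = discriminator(fake)            # graph now spans generator AND discriminator
D_loss = loss(pred, fake_label)
D_loss.backward()                     # populates .grad in BOTH modules

print(generator.weight.grad is None)  # False -> generator gradients were accumulated
D_opt.step()                          # only discriminator weights move, but the stale
                                      # generator gradients stay in .grad and will be
                                      # added to whatever E_loss.backward() produces later

That stale gradient is exactly what a fresh zero_grad() before the generator update (option 1) protects against.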

Correct case

We should take the generator out of the equation. Taken from the PyTorch DCGAN example, there is a little line like this:

# Generate fake image batch with G
fake = generator(noise)
label.fill_(fake_label)
# DETACH HERE
output = discriminator(fake.detach()).view(-1)

What detach does is "stop" the gradient by detaching the tensor from the computational graph, so gradients will not be backpropagated along this variable. This leaves the gradients of the generator untouched, so no accumulation happens.
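A quick way to check this, reusing noise, fake_label and loss from the toy sketch above (fresh modules, so no gradients have been populated yet):

generator = nn.Linear(8, 4)
discriminator = nn.Linear(4, 1)

fake = generator(noise)                  # the forward pass still runs normally...
pred = discriminator(fake.detach())      # ...but the graph is cut at fake
D_loss = loss(pred, fake_label)
D_loss.backward()

print(all(p.grad is None for p in generator.parameters()))           # True
print(all(p.grad is not None for p in discriminator.parameters()))   # True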

Another way (IMO better) would be to use a with torch.no_grad(): block like this:

# Generate fake image batch with G
with torch.no_grad():
    fake = generator(noise)
label.fill_(fake_label)
# NO DETACH NEEDED
output = discriminator(fake).view(-1)

This way the generator operations do not become part of the graph at all, so we also get better performance (in the first case they would be recorded and only detached afterwards).

Finally

Yeah, all in all the first option is better for standard GANs, as one doesn't have to think about such stuff (people implementing it should, but readers should not). Though there are also other approaches, like a single optimizer for both generator and discriminator (in this case one cannot zero_grad() only a subset of the parameters, e.g. the encoder; see the sketch below), weight sharing and others which further clutter the picture.
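For reference, a sketch of that single-optimizer variant (assuming generator and discriminator are the nn.Module instances from the snippets above):

import itertools

joint_opt = torch.optim.Adam(
    itertools.chain(generator.parameters(), discriminator.parameters())
)

# joint_opt.zero_grad() clears every parameter this optimizer owns; there is no
# built-in way to zero only the generator/encoder subset through it, so you would
# have to reset p.grad manually for that subset if you needed to.
joint_opt.zero_grad()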

with torch.no_grad() should alleviate the problem in all/most cases, as far as I can tell and imagine ATM.
