Understanding when to call zero_grad() in PyTorch when training with multiple losses
I am going through an open-source implementation of a domain-adversarial model (GAN-like). The implementation uses PyTorch and I am not sure they use zero_grad() correctly. They call zero_grad() for the encoder optimizer (aka the generator) before updating the discriminator loss. However, zero_grad() is hardly documented, and I couldn't find information about it.
Here is pseudocode comparing a standard GAN training loop (option 1) with their implementation (option 2). I think the second option is wrong, because it may accumulate the D_loss gradients into the parameters tracked by E_opt. Can someone tell whether these two pieces of code are equivalent?
Option 1 (a standard GAN implementation):
X, y = get_D_batch()
D_opt.zero_grad()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()
X, y = get_E_batch()
E_opt.zero_grad()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
Option 2 (calling zero_grad() for both optimizers at the beginning):
E_opt.zero_grad()
D_opt.zero_grad()
X, y = get_D_batch()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()
X, y = get_E_batch()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
It depends on the params argument of the torch.optim.Optimizer subclass (e.g. torch.optim.SGD) and the exact structure of the model.
Assuming E_opt and D_opt have disjoint sets of parameters (model.encoder and model.decoder do not share weights), something like this:
E_opt = torch.optim.Adam(model.encoder.parameters())
D_opt = torch.optim.Adam(model.decoder.parameters())
both options MIGHT indeed be equivalent (see the commentary on your source code below; additionally, I have added backward(), which is really important here, and changed model to discriminator and generator as appropriate, since I assume that's the case):
# Starting with zero gradient
E_opt.zero_grad()
D_opt.zero_grad()
# See comment below for possible cases
X, y = get_D_batch()
pred = discriminator(X)
D_loss = loss(pred, y)
# This will accumulate gradients in discriminator only
# OR in discriminator and generator, depends on other parts of code
# See below for commentary
D_loss.backward()
# Correct weights of discriminator
D_opt.step()
# The generator update relies only on random noise input,
# so the discriminator batch is not part of this equation
X, y = get_E_batch()
pred = generator(X)
E_loss = loss(pred, y)
E_loss.backward()
# So only parameters of generator are updated always
E_opt.step()
Now it all comes down to how get_D_batch feeds data to the discriminator.
If it passes real samples, this is not a problem: the generator is not involved, and only the discriminator takes part in the operation.
Gradient accumulation may indeed occur here, though. It would occur if get_D_batch simply called X = generator(noise) and passed this data to the discriminator. In that case both discriminator and generator have their gradients accumulated during backward(), as both are used in the forward pass.
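The accumulation behaviour itself is easy to verify in isolation; here is a minimal sketch (not from the original code) showing that .grad sums across backward() calls unless cleared:

```python
import torch

w = torch.ones(1, requires_grad=True)

# First backward pass: d(2w)/dw = 2
(w * 2).backward()
print(w.grad)  # tensor([2.])

# Second backward pass without zero_grad(): d(3w)/dw = 3 is ADDED
(w * 3).backward()
print(w.grad)  # tensor([5.])
```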
We should take the generator out of the equation. Taken from the PyTorch DCGAN example, there is a little line like this:
# Generate fake image batch with G
fake = generator(noise)
label.fill_(fake_label)
# DETACH HERE
output = discriminator(fake.detach()).view(-1)
What detach does is "stop" the gradient by detaching the tensor from the computational graph, so gradients will not be backpropagated along this variable. This leaves the gradients of the generator untouched: it receives no gradients, so no accumulation happens.
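A minimal sketch of this effect, using hypothetical one-layer stand-ins for the generator and discriminator (names and shapes are mine, not from the original code):

```python
import torch
import torch.nn as nn

generator = nn.Linear(4, 4)      # stand-in for the real generator
discriminator = nn.Linear(4, 1)  # stand-in for the real discriminator

noise = torch.randn(2, 4)
fake = generator(noise)

# Backpropagation through the detached tensor stops at `fake`,
# so the generator accumulates nothing:
discriminator(fake.detach()).sum().backward()

print(generator.weight.grad)               # None
print(discriminator.weight.grad is None)   # False
```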
Another way (IMO better) would be to use a with torch.no_grad(): block, like this:
# Generate fake image batch with G
with torch.no_grad():
fake = generator(noise)
label.fill_(fake_label)
# NO DETACH NEEDED
output = discriminator(fake).view(-1)
This way the generator's operations do not build part of the graph at all, so we get better performance (in the first case the graph is built but then detached afterwards).
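The difference between the two approaches is observable on the output tensor; a small sketch with a hypothetical linear stand-in for the generator:

```python
import torch
import torch.nn as nn

generator = nn.Linear(4, 4)  # stand-in generator
noise = torch.randn(2, 4)

with torch.no_grad():
    fake = generator(noise)        # no graph is recorded at all

fake2 = generator(noise).detach()  # graph is built, then cut off

# Either tensor is safe to feed to the discriminator without
# accumulating gradients in the generator:
print(fake.requires_grad, fake2.requires_grad)  # False False
```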
All in all, the first option is better for standard GANs, as one doesn't have to think about such things (the people implementing it should, but readers should not). There are also other approaches, though, such as a single optimizer for both generator and discriminator (in which case one cannot zero_grad() only a subset of the parameters, e.g. the encoder), weight sharing, and others, which further clutter the picture.
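For illustration, with a single shared optimizer one would have to clear the gradients of a parameter subset manually; a sketch with a hypothetical encoder/decoder model (the names are mine, not from the original code):

```python
import torch
import torch.nn as nn

# Hypothetical two-part model driven by one optimizer
model = nn.ModuleDict({
    "encoder": nn.Linear(4, 4),
    "decoder": nn.Linear(4, 1),
})
opt = torch.optim.Adam(model.parameters())

out = model["decoder"](model["encoder"](torch.randn(2, 4))).sum()
out.backward()

# opt.zero_grad() would clear every parameter's gradient;
# clearing only the encoder's has to be done by hand:
for p in model["encoder"].parameters():
    p.grad = None
```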
with torch.no_grad() should alleviate the problem in all or most cases, as far as I can tell and imagine at the moment.