Why does detach() need to be called on the variable in this example?
I was going through this example - https://github.com/pytorch/examples/blob/master/dcgan/main.py and I have a basic question.
```python
fake = netG(noise)
label = Variable(label.fill_(fake_label))
output = netD(fake.detach())  # detach to avoid training G on these labels
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()
```
I understand why we call detach() on the variable fake: so that no gradients are computed for the Generator parameters. My question is, does it matter, since optimizerD.step() is going to update the parameters associated with the Discriminator only?
optimizerD is defined as:
```python
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
```
Besides, in the next step, when we update the parameters for the Generator, we first call netG.zero_grad(), which removes all previously computed gradients. Moreover, when we update the parameters of the G network, we do output = netD(fake). Here, we are not using detach. Why?
So why is detaching the variable (line 3) necessary in the above code?
Solution 1:[1]
ORIGINAL ANSWER (WRONG / INCOMPLETE)
You're right: optimizerD only updates netD, and the gradients on netG are not used before netG.zero_grad() is called, so detaching is not necessary; it just saves time, because you're not computing gradients for the generator.
You're basically also answering your other question yourself: you don't detach fake in the second block because you specifically want to compute gradients on netG to be able to update its parameters.
Note how in the second block real_label is used as the corresponding label for fake, so if the discriminator finds the fake input to be real, the final loss is small, and vice versa, which is precisely what you want for the generator. Not sure if that's what confused you, but it's really the only difference compared to training the discriminator on fake inputs.
EDIT
Please see FatPanda's comment! My original answer is in fact incorrect. PyTorch destroys (parts of) the compute graph when .backward() is called. Without detaching before errD_fake.backward(), the later errG.backward() call would not be able to backprop into the generator, because the required graph is no longer available (unless you specify retain_graph=True). I'm relieved Soumith made the same mistake :D
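To see the failure concretely, here is a minimal sketch, with a one-layer Linear standing in for each of netG and netD (hypothetical stand-ins, not the real DCGAN models): the first backward() frees the shared part of the graph, so the second backward() through fake fails.

```python
import torch

G = torch.nn.Linear(4, 4)      # hypothetical stand-in for netG
D = torch.nn.Linear(4, 1)      # hypothetical stand-in for netD

fake = G(torch.randn(2, 4))    # fake is connected to G's graph

errD_fake = D(fake).mean()
errD_fake.backward()           # frees the saved buffers between noise and fake

errG = D(fake).mean()          # fresh pass through D, but fake's history
try:                           # into G was already freed above
    errG.backward()
except RuntimeError as e:
    print(e)                   # "Trying to backward through the graph a second time ..."
```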
Solution 2:[2]
The top-voted answer is INCORRECT/INCOMPLETE!
Check this: https://github.com/pytorch/examples/issues/116, and have a look at @plopd's answer:
This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G again when we actually update the generator. If we do not detach, then, although fake is not needed for the gradient update of D, it will still be added to the computational graph, and as a consequence of the backward pass, which clears all the intermediate variables in the graph (retain_graph=False by default), fake won't be available when G is updated.
This post also clarifies a lot: https://zhuanlan.zhihu.com/p/43843694 (In Chinese).
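To make @plopd's point concrete, here is a condensed sketch of the update order from the linked example (the tiny Linear networks and shapes are illustrative assumptions, not the real DCGAN architecture): fake is forward-passed through G exactly once, detached for the D step, and then reused with its graph intact for the G step.

```python
import torch
import torch.nn as nn

netG = nn.Linear(100, 64)                             # illustrative stand-in for netG
netD = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())  # illustrative stand-in for netD
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4)
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4)

noise = torch.randn(8, 100)
fake = netG(noise)                 # forward through G once; no second pass needed

# D step: detach, so backward() neither touches nor frees G's graph
netD.zero_grad()
label = torch.zeros(8, 1)          # fake_label
errD_fake = criterion(netD(fake.detach()), label)
errD_fake.backward()               # only the D-side graph is consumed
optimizerD.step()

# G step: reuse the SAME fake; its history into netG is still available
netG.zero_grad()
label.fill_(1.0)                   # real_label, to "fool" the discriminator
errG = criterion(netD(fake), label)
errG.backward()                    # backprops into netG as intended
optimizerG.step()
```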
Solution 3:[3]
Because the fake variable is now part of the Generator's graph [1], but you don't want that here. So you have to "detach" it from that graph before putting it into the Discriminator.
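A quick way to see what "part of the Generator graph" means (a hypothetical Linear stands in for the generator): fake carries a grad_fn pointing back into netG, while the detached tensor does not.

```python
import torch

netG = torch.nn.Linear(100, 64)        # hypothetical stand-in for the generator
fake = netG(torch.randn(1, 100))

print(fake.grad_fn)                    # e.g. <AddmmBackward0 ...> -> linked to netG
print(fake.detach().grad_fn)           # None -> cut off from the generator's graph
print(fake.detach().requires_grad)     # False
```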
Solution 4:[4]
That's because if you don't use fake.detach() in output = netD(fake.detach()).view(-1), then fake is just an intermediate variable in the whole computational graph, which tracks gradients through both netG and netD. When you call errD_fake.backward(), the graph is freed, which means there is no more gradient information about netG left in the computational graph. Then, when you call errG.backward() later, it will raise an error like
Trying to backward through the graph a second time
If you don't use fake.detach(), you can call errD_fake.backward(retain_graph=True) instead.
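A short sketch of the retain_graph=True alternative mentioned above (reusing hypothetical one-layer stand-ins for the networks): it avoids the error, but at the cost of keeping G's buffers alive and accumulating generator gradients during the discriminator step, which netG.zero_grad() then has to throw away.

```python
import torch

G = torch.nn.Linear(4, 4)               # hypothetical stand-in for netG
D = torch.nn.Linear(4, 1)               # hypothetical stand-in for netD

fake = G(torch.randn(2, 4))

errD_fake = D(fake).mean()
errD_fake.backward(retain_graph=True)   # graph kept alive; G.weight.grad is now populated

errG = D(fake).mean()
errG.backward()                         # works, no "second time" error
print(G.weight.grad.shape)              # gradients reached G in both passes
```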
Solution 5:[5]
Let me explain. The role of detach is to freeze gradient flow. Whether we are training the discriminator or the generator, the update always revolves around log D(G(z)). For the discriminator, freezing G does not affect the overall gradient update (the inner function is treated as a constant, which does not stop the outer function from computing its gradient); conversely, if D were frozen, there would be no way to complete the gradient update at all. That is why we do not freeze D's gradients when training the generator: we do compute gradients for D, but we never update D's weights (only optimizerG.step() is called), so the discriminator is not changed while the generator is trained.
You may ask: why, then, do we add detach when training the discriminator? Isn't that an extra step? No: because freezing G's gradients speeds up training, we use it wherever we can, so it is not an extra task. When we train the generator, because of log D(G(z)), there is no way to freeze D's gradients, so we do not write detach there.
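A sketch (hypothetical stand-in networks again) illustrating this last point: during the generator update, gradients do flow through netD, but since only the generator's optimizer steps, netD's weights are left unchanged.

```python
import torch

netG = torch.nn.Linear(10, 4)      # hypothetical stand-in networks
netD = torch.nn.Linear(4, 1)
optimizerG = torch.optim.SGD(netG.parameters(), lr=0.1)

d_weight_before = netD.weight.clone()
errG = netD(netG(torch.randn(3, 10))).mean()
errG.backward()

print(netD.weight.grad is not None)               # True: D's gradients WERE computed
optimizerG.step()                                 # but only G's weights move
print(torch.equal(netD.weight, d_weight_before))  # True: D is unchanged
# netD.zero_grad() before D's next update discards these stale gradients.
```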
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | |
| Solution 3 | loose11 |
| Solution 4 | batman47steam |
| Solution 5 | einstellung |