
Inner workings of PyTorch autograd.grad for inner derivatives

Consider the following code:

import torch
from torch import autograd

x = torch.tensor(2.0, requires_grad=True)

y = torch.square(x)
grad = autograd.grad(y, x)

x = x + grad[0]

y = torch.square(x)
grad2 = autograd.grad(y, x)

First, we have that ∇(x^2) = 2x. In my understanding, grad2 = ∇((x + ∇(x^2))^2) = ∇((x + 2x)^2) = ∇((3x)^2) = 9∇(x^2) = 18x. As expected, grad = 4.0 = 2x, but grad2 = 12.0 = 6x, and I don't understand where that comes from. It feels as though the 3 comes from the expression I had, but it is not squared, and the 2 comes from the traditional derivative. Could somebody help me understand why this is happening? Furthermore, how far back does the computational graph that stores the gradients go?
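
To make the numbers concrete, here is the same arithmetic written out in plain Python (no autograd involved), contrasting what I expected with the value I actually observe:

x = 2.0
grad = 2 * x                 # gradient of x**2 at x: 2x = 4.0
x_new = x + grad             # x + 2x = 3x = 6.0
expected = 18 * x            # d/dx (3x)**2 = 18x = 36.0 (what I expected for grad2)
observed = 2 * x_new         # 2 * 6.0 = 12.0 (the value autograd.grad actually gives me)
print(expected, observed)    # 36.0 12.0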

Specifically, I am coming from a meta-learning perspective, where one is interested in computing a quantity of the following form: ∇L(theta - alpha * ∇L(theta)) = (1 - alpha * ∇^2 L(theta)) ∇L(theta - alpha * ∇L(theta)) (here the derivative is taken with respect to theta). This computation, let's call it A, therefore includes a second derivative. It is quite different from the following: ∇_{theta - alpha * ∇L(theta)} L(theta - alpha * ∇L(theta)) = ∇_beta L(beta), which I will call B.
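
To pin down what I mean by A and B, here is a toy example with L(theta) = theta^2, theta = 2 and alpha = 1 (arbitrary values, purely for illustration), computing both quantities analytically:

theta, alpha = 2.0, 1.0

def L_prime(t):
    # analytic first derivative of L(t) = t**2
    return 2.0 * t

L_double_prime = 2.0                                 # second derivative of t**2 is constant

beta = theta - alpha * L_prime(theta)                # inner update: 2 - 1*4 = -2.0

A = (1.0 - alpha * L_double_prime) * L_prime(beta)   # differentiates through the inner update: (1 - 2) * (-4) = 4.0
B = L_prime(beta)                                    # gradient w.r.t. beta itself: -4.0

print(A, B)  # 4.0 -4.0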

Hopefully, it is clear how the snippet above relates to the meta-learning quantity I just described. My overall question is: under what circumstances does PyTorch perform computation A vs. computation B when using autograd.grad? I'd appreciate any explanation that goes into technical detail about how this particular case is handled by autograd.

PS: The original code that made me wonder about this is here; in particular, lines 69 through 106, and then line 193, which is where they use autograd.grad. That code is even more unclear to me because they do a lot of model.clone() and so on.

If the question is unclear in any way, please let me know.

I made a few changes:

  1. I am not sure what torch.rand(2.0) is supposed to do, so, following the text, I simply set x to 2.
  2. An intermediate variable z is added so that we can compute the gradient w.r.t. the original variable x; in your code, x is overwritten.
  3. Set create_graph=True to compute higher-order gradients. See https://pytorch.org/docs/stable/generated/torch.autograd.grad.html
import torch
from torch import autograd

x = torch.ones(1, requires_grad=True) * 2

y = torch.square(x)
grad = autograd.grad(y, x, create_graph=True)  # dy/dx = 2x = 4, kept in the graph so it can be differentiated again
z = x + grad[0]                                # z = x + 2x = 3x = 6

y = torch.square(z)
grad2 = autograd.grad(y, x)                    # dy/dx = 2z * dz/dx = 12 * 3 = 36
# yours is more like autograd.grad(y, z)

print(x)      # tensor([2.], ...)
print(grad)   # (tensor([4.], ...),)  -> 2x
print(grad2)  # (tensor([36.]),)      -> 18x
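
For comparison, here is a variant (not part of the snippet above) that takes the gradient with respect to z instead of x, which is effectively what the question's code did; it reproduces the 12.0:

import torch
from torch import autograd

x = torch.ones(1, requires_grad=True) * 2

y = torch.square(x)
grad = autograd.grad(y, x, create_graph=True)  # 2x = 4
z = x + grad[0]                                # 3x = 6

y = torch.square(z)
grad_b = autograd.grad(y, z)                   # dy/dz = 2z = 12

print(grad_b)  # (tensor([12.]),)

So grad2 above (36.0 = 18x) is computation A: it appears when the inner gradient is built with create_graph=True and the outer gradient is taken with respect to the original variable x. This variant (12.0 = 6x) is computation B: it is what you get when you differentiate with respect to the updated tensor, or when the inner gradient is not part of the graph.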
