PyTorch one of the variables needed for gradient computation has been modified by an inplace operation

I'm implementing a policy gradient method in PyTorch. I wanted to move the network update into the loop, but then it stopped working. I'm still a PyTorch newbie, so sorry if the answer is obvious.

Here is the original code that works:

self.policy.optimizer.zero_grad()
G = T.tensor(G, dtype=T.float).to(self.policy.device) 

loss = 0
for g, logprob in zip(G, self.action_memory):
    loss += -g * logprob
                                 
loss.backward()
self.policy.optimizer.step()

And after the change:

G = T.tensor(G, dtype=T.float).to(self.policy.device) 

loss = 0
for g, logprob in zip(G, self.action_memory):
    loss = -g * logprob
    self.policy.optimizer.zero_grad()
                                 
    loss.backward()
    self.policy.optimizer.step()

I get the error:

File "g:\VScode_projects\pytorch_shenanigans\policy_gradient.py", line 86, in learn
    loss.backward()
  File "G:\Anaconda3\envs\pytorch_env\lib\site-packages\torch\tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "G:\Anaconda3\envs\pytorch_env\lib\site-packages\torch\autograd\__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128, 4]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I read that this RuntimeError often has to do with needing to clone something, because the same tensor is being used to compute itself, but I can't make heads or tails of what is wrong in my case.
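The traceback's own hint can be followed directly: torch.autograd.set_detect_anomaly(True) makes the forward pass record a traceback for every op, so the error raised in backward also names the forward call whose gradient could not be computed. A minimal sketch (the three-line repro below is illustrative, not taken from the question's code):

```python
import torch

# Enable anomaly detection, as the error message suggests. The forward
# pass now records tracebacks, so the backward error also points at the
# forward op (here, w.exp()) whose saved output was modified in place.
torch.autograd.set_detect_anomaly(True)

w = torch.ones(2, requires_grad=True)
v = w.exp()   # exp's backward needs its own output v
v += 1        # in-place edit bumps v's version counter
try:
    v.sum().backward()
except RuntimeError as e:
    # Message contains "modified by an inplace operation", and anomaly
    # detection adds a traceback of where v was produced in the forward.
    print(type(e).__name__)  # → RuntimeError
```

Running the failing script once with this flag set usually pinpoints the offending line; it slows training down, so it is meant for debugging only.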

This line, loss += -g * logprob, is what is wrong in your case.

Change it to this:

loss = loss + (-g * logprob)

And yes, they are different. They compute the same value, but in different ways: once loss is a tensor, += is an in-place operation that mutates the existing tensor (and bumps its version counter), which invalidates any graph node that saved the old value for backward. loss = loss + (...) allocates a new tensor each time and leaves the recorded graph untouched.
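The difference is visible through the version counter that autograd checks, which is what the error message is reporting ("is at version 2; expected version 1"). A small sketch, using the internal _version attribute purely for illustration (the tensor names below are made up, not from the question):

```python
import torch

w = torch.ones(3, requires_grad=True)

# In-place add: same storage is mutated, version counter is bumped.
b = w * 2
b += 1
print(b._version)  # → 1

# Out-of-place add: a brand-new tensor; the recorded node keeps version 0.
a = w * 2
a = a + 1
print(a._version)  # → 0

# Accumulating the loss out-of-place therefore keeps the graph valid:
G = [1.0, 2.0, 3.0]
logprobs = [w.sum(), (w * 2).sum(), (w * 3).sum()]
loss = 0
for g, logprob in zip(G, logprobs):
    loss = loss + (-g * logprob)  # binds loss to a new tensor each time
loss.backward()
print(w.grad)  # → tensor([-14., -14., -14.])
```

Backward only raises when some recorded op actually needs the overwritten value; that is why in-place ops sometimes appear to work and then fail after a small refactor like the one in the question.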
