How to break PyTorch autograd with in-place ops

I'm trying to better understand the role of in-place operations in PyTorch autograd. My understanding is that they can cause problems, since they may overwrite values needed during the backward step.

I'm trying to build an example where an in-place operation breaks automatic differentiation; my idea is to overwrite some value needed during backpropagation after it has been used to compute another tensor.

I'm using item assignment as the in-place operation (I tried += with the same result). I double-checked that it really is an in-place op like this:

x = torch.arange(5, dtype=torch.float, requires_grad=True)
y = x
y[3] = -1
print(x)

prints:

tensor([ 0.,  1.,  2., -1.,  4.], grad_fn=<CopySlices>)
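
Another sanity check I can run (my own, using the internal _version attribute, which as far as I understand is what autograd itself uses to track in-place writes, so it may be version-dependent): the version counter goes up after the assignment, while no new tensor is created.

import torch

a = torch.arange(5, dtype=torch.float)  # plain tensor, just to watch the counter
print(a._version)   # 0
a[3] = -1           # the same item assignment as above, done in place
print(a._version)   # 1 -- the existing storage was modified, no new tensor was made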

This is my attempt to break autograd:

  1. Without the in-place op:
x = torch.arange(5, dtype=torch.float, requires_grad=True)
out1 = x ** 2
out2 = out1 / 10
# out1[3] += 100  
out2.sum().backward()
print(x.grad)

This prints:

tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000])

  2. With the in-place op:
x = torch.arange(5, dtype=torch.float, requires_grad=True)
out1 = x ** 2
out2 = out1 / 10
out1[3] = 0  
out2.sum().backward()
print(x.grad)

This prints:

tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000])

I was expecting to obtain different grads.

  • What is the item assignment doing? I don't understand the grad_fn=<CopySlices>.
  • Why does it return the same grads?
  • Is there a working example of in-place operations that break autograd?
  • Is there a list of PyTorch ops that are not compatible with the backward pass?

A working example of an in-place operation that breaks autograd:

  x = torch.ones(5, requires_grad=True)
  x2 = (x + 1).sqrt()
  z = (x2 - 10)
  x2[0] = -1
  z.sum().backward()

Raises:

RuntimeError: one of the variables needed for gradient computation has been modified by an in-place operation: [torch.FloatTensor [5]], which is output 0 of SqrtBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
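
The check that fires here is autograd's version counter: every in-place op bumps the tensor's version, and any node that saved that tensor for its backward remembers the version at save time, so the mismatch is caught when backward() runs. sqrt trips it because its backward reuses its own output (d sqrt(u)/du = 1/(2 sqrt(u))), so SqrtBackward saves x2, and the write x2[0] = -1 invalidates it. Your original attempt does not trip it because out2 = out1 / 10 divides by a plain Python number: the local derivative is the constant 1/10, so out1 is never saved, and pow's backward only needs x, which you never modified; that is also why the grads come out identical. As a sketch of that (my own variant, the grad-requiring divisor is my addition, not from your code): make the divisor a tensor that requires grad, and the same in-place write on out1 now fails.

import torch

x = torch.arange(5, dtype=torch.float, requires_grad=True)
divisor = torch.tensor(10.0, requires_grad=True)  # assumption: grad-requiring divisor

out1 = x ** 2
out2 = out1 / divisor   # grad w.r.t. divisor needs out1, so autograd saves out1 here
out1[3] = 0             # in-place write bumps out1's version counter
out2.sum().backward()   # raises the same "modified by an inplace operation" error

As for grad_fn=<CopySlices>: as far as I know, the item assignment is lowered to a slice of the tensor followed by an in-place copy_ into it, and autograd records that pair as a single CopySlices node.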
