[英]How to break PyTorch autograd with in-place ops
I'm trying to understand better the role of in-place operations in PyTorch autograd.我试图更好地理解就地操作在 PyTorch autograd 中的作用。 My understanding is that they are likely to cause problems since they may overwrite values needed during the backward step.
我的理解是它们可能会导致问题,因为它们可能会覆盖后退步骤所需的值。
I'm trying to build an example where an in-place operation breaks the auto differentiation, my idea is to overwrite some value needed during the backpropagation after it has been used to compute some other tensor.我正在尝试构建一个示例,其中就地操作破坏了自动微分,我的想法是在使用它来计算其他张量后覆盖反向传播期间所需的一些值。
I'm using the assignment as the in-place operation (I tried +=
with the same result), I double-checked it is an in-place op in this way:我使用分配作为就地操作(我尝试了
+=
并得到了相同的结果),我以这种方式仔细检查了它是一个就地操作:
x = torch.arange(5, dtype=torch.float, requires_grad=True)
y = x
y[3] = -1
print(x)
prints:印刷:
tensor([ 0., 1., 2., -1., 4.], grad_fn=<CopySlices>)
This is my attempt to break autograd:这是我打破 autograd 的尝试:
x = torch.arange(5, dtype=torch.float, requires_grad=True)
out1 = x ** 2
out2 = out1 / 10
# out1[3] += 100
out2.sum().backward()
print(x.grad)
This prints这打印
tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000])
x = torch.arange(5, dtype=torch.float, requires_grad=True)
out1 = x ** 2
out2 = out1 / 10
out1[3] = 0
out2.sum().backward()
print(x.grad)
This prints:这打印:
tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000])
I was expecting to obtain differents grads.我期待获得不同的毕业生。
grad_fn=<CopySlices>
.grad_fn=<CopySlices>
。A working example of an in-place operations that breaks autograd:破坏 autograd 的就地操作的工作示例:
x = torch.ones(5, requires_grad=True)
x2 = (x + 1).sqrt()
z = (x2 - 10)
x2[0] = -1
z.sum().backward()
Raises:提高:
RuntimeError: one of the variables needed for gradient computation has been modified by an in-place operation: [torch.FloatTensor [5]], which is output 0 of SqrtBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.