Pytorch Autograd gives different gradients when using .clamp instead of torch.relu

I'm still working on my understanding of the PyTorch autograd system. One thing I'm struggling with is understanding why .clamp(min=0) and nn.functional.relu() seem to have different backward passes.

It's especially confusing because PyTorch tutorials, such as https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-nn, use .clamp as if it were equivalent to relu.

I found this while analysing the gradients of a simple fully connected net with one hidden layer and a relu activation (linear in the output layer).

To my understanding, the output of the following code should be all zeros. I hope someone can show me what I am missing.

import torch
dtype = torch.float

x = torch.tensor([[3,2,1],
                  [1,0,2],
                  [4,1,2],
                  [0,0,1]], dtype=dtype)

y = torch.ones(4,4)

w1_a = torch.tensor([[1,2],
                     [0,1],
                     [4,0]], dtype=dtype, requires_grad=True)
w1_b = w1_a.clone().detach()
w1_b.requires_grad = True



w2_a = torch.tensor([[-1, 1],
                     [-2, 3]], dtype=dtype, requires_grad=True)
w2_b = w2_a.clone().detach()
w2_b.requires_grad = True


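# Version a: ReLU via torch.nn.functional.relu
y_hat_a = torch.nn.functional.relu(x.mm(w1_a)).mm(w2_a)
y_a = torch.ones_like(y_hat_a)
# Version b: the same network, with .clamp(min=0) as the activation
y_hat_b = x.mm(w1_b).clamp(min=0).mm(w2_b)
y_b = torch.ones_like(y_hat_b)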

loss_a = (y_hat_a - y_a).pow(2).sum()
loss_b = (y_hat_b - y_b).pow(2).sum()

loss_a.backward()
loss_b.backward()

print(w1_a.grad - w1_b.grad)
print(w2_a.grad - w2_b.grad)

# OUT:
# tensor([[  0.,   0.],
#         [  0.,   0.],
#         [  0., -38.]])
# tensor([[0., 0.],
#         [0., 0.]])
# 
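A quick diagnostic sketch helps pin down where the -38 comes from. It reuses the same x, w1 and w2 values as above (without requires_grad, since the gradient is computed by hand here). It looks for hidden pre-activations that are exactly 0 and computes the upstream gradient dL/da at that unit. The names h, a, grad_a, unit_grad and expected_diff are illustrative only.

```python
import torch

dtype = torch.float
x = torch.tensor([[3, 2, 1],
                  [1, 0, 2],
                  [4, 1, 2],
                  [0, 0, 1]], dtype=dtype)
w1 = torch.tensor([[1, 2],
                   [0, 1],
                   [4, 0]], dtype=dtype)
w2 = torch.tensor([[-1, 1],
                   [-2, 3]], dtype=dtype)

# Hidden pre-activation; any entry that is exactly 0 sits on the kink of ReLU.
h = x.mm(w1)
zero_positions = (h == 0).nonzero().tolist()
print(h)
print(zero_positions)

# Upstream gradient w.r.t. the hidden activation, computed by hand:
# dL/dy_hat = 2 * (y_hat - 1), then dL/da = dL/dy_hat @ w2^T.
a = h.clamp(min=0)
y_hat = a.mm(w2)
grad_a = (2 * (y_hat - 1)).mm(w2.t())
print(grad_a[3, 1])

# At that unit relu passes 0 back and clamp passes grad_a[3, 1], so the
# difference w1_a.grad - w1_b.grad is -(outer product of x[3] with that gradient).
unit_grad = torch.zeros(2)
unit_grad[1] = grad_a[3, 1]
expected_diff = -torch.outer(x[3], unit_grad)
print(expected_diff)
```

Because x[3] is [0, 0, 1], the whole difference lands in entry [2, 1], which matches the printed output above.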

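The reason is that clamp and relu produce different gradients at exactly 0. You can check this with a scalar tensor x = 0 (created with requires_grad=True) by comparing (x.clamp(min=0) - 1.0).pow(2).backward() with (relu(x) - 1.0).pow(2).backward(). The resulting x.grad is 0 for the relu version but -2 for the clamp version. That means that at x == 0, relu uses a gradient of 0, while clamp uses a gradient of 1. In your example, x.mm(w1) is exactly 0 in its last row, second column, which is why one entry of w1.grad differs; w2.grad is identical because both versions produce the same forward activations.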
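The scalar check above can be written out as a small script, together with a per-element view of the local derivative each op uses around the kink. The helper names squared_loss_grad_at_zero, local_grads and clamp0 are made up for this sketch. The values reflect the backward formulas of the PyTorch versions I am aware of; since neither function is differentiable at 0, the value there is a convention rather than a mathematical necessity.

```python
import torch

def squared_loss_grad_at_zero(activation):
    """d/dx of (activation(x) - 1)^2 at x = 0, as computed by autograd."""
    x = torch.tensor(0.0, requires_grad=True)
    (activation(x) - 1.0).pow(2).backward()
    return x.grad.item()

def local_grads(activation, points=(-1.0, 0.0, 1.0)):
    """Per-element derivative that autograd uses for `activation` at each point."""
    x = torch.tensor(points, requires_grad=True)
    # Summing gives every element an upstream gradient of 1,
    # so x.grad is exactly the local derivative of the activation.
    activation(x).sum().backward()
    return x.grad.tolist()

relu = torch.relu
clamp0 = lambda t: t.clamp(min=0)

print(squared_loss_grad_at_zero(relu), squared_loss_grad_at_zero(clamp0))
print(local_grads(relu))
print(local_grads(clamp0))
```

Away from 0 the two agree; only the point x == 0 differs (relu gives 0, clamp gives 1). If you need results that match relu exactly, use torch.relu or torch.nn.functional.relu rather than clamp.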
