
How to calculate the gradient of the loss with respect to the input?

I have a pre-trained PyTorch model. I need to calculate the gradient of the loss with respect to the network's inputs using this model (without further training, using only the pre-trained weights).

I wrote the following code, but I am not sure whether it is correct.

    import torch
    from torch.utils.data import DataLoader

    test_X, test_y = load_data(mode='test')
    testset_original = MyDataset(test_X, test_y, transform=default_transform)
    testloader = DataLoader(testset_original, batch_size=32, shuffle=True)

    model = MyModel(device=device).to(device)
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode: disables dropout/batch-norm updates; gradients still flow

    gradient_losses = []
    for i, data in enumerate(testloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        inputs.requires_grad = True  # track gradients w.r.t. the input batch
        output = model(inputs)
        loss = loss_function(output, labels)  # pass the targets too; most criteria take (output, target)
        loss.backward()
        gradient_losses.append(inputs.grad)

My question is: is this list gradient_losses actually storing what I wish to store? If not, what is the correct way to do that?

is this list gradient_losses actually storing what I wish to store?

Yes, if you are looking to get the derivative of the loss with respect to the input, then that is the correct way to do it. Here is a minimal example: take f(x) = a*x. Then df/dx = a.

>>> x = torch.rand(10, requires_grad=True)
>>> y = torch.rand(10)
>>> a = torch.tensor([3.], requires_grad=True)

>>> loss = a*x - y
>>> loss.mean().backward()

>>> x.grad
tensor([0.3000, 0.3000, ..., 0.3000, 0.3000])

which, in this case, is equal to a / len(x): the mean divides each element's contribution by len(x), so every entry is 3 / 10 = 0.3.
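As a quick sanity check, continuing the snippet above (the broadcast of the length-1 tensor a against x.grad is intentional):

>>> torch.allclose(x.grad, a.detach() / len(x))
True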

Do note that what you extract with inputs.grad is the gradient of the batch-averaged loss: each row is the gradient of that sample's loss scaled by 1/batch_size, not the gradient of each individual sample's loss on its own. See the sketch below for how to recover unscaled per-sample gradients.
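If you want each sample's own gradient, a minimal sketch is to compute per-sample losses with reduction='none' and backpropagate their sum, so no 1/batch_size factor is introduced. This assumes the model processes samples independently (no batch norm); the nn.Linear model and the shapes here are hypothetical stand-ins:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 3)                        # stand-in for the real model
    inputs = torch.rand(8, 4, requires_grad=True)  # hypothetical batch: 8 samples, 4 features
    labels = torch.randint(0, 3, (8,))

    criterion = nn.CrossEntropyLoss(reduction='none')  # one loss value per sample

    per_sample_loss = criterion(model(inputs), labels)  # shape: (8,)
    # Back-propagate the *sum*: each input row only feeds its own loss term,
    # so inputs.grad[i] == d(loss_i)/d(inputs[i]) with no 1/batch_size scaling.
    per_sample_loss.sum().backward()
    per_sample_grads = inputs.grad                      # shape: (8, 4)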


Also, you don't need to .clone() your input gradients: they are not parameters of the model, so model.zero_grad() won't zero them.
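One memory-related note (an addition, not something the original loop gets wrong): each appended inputs.grad tensor stays on the GPU, so for a large test set you may prefer to move it off-device as you go:

    gradient_losses.append(inputs.grad.detach().cpu())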
