
Pytorch: How to access CrossEntropyLoss() gradient?

I want to modify the tensor that stores the CrossEntropyLoss() gradient, that is, P(i)-T(i). Where is it stored and how do I access it?

code:

import torch
import torch.nn as nn

input = torch.randn(3, 5, requires_grad=True)
input.register_hook(lambda x: print(" \n input hook: ",x))
print(input)
target = torch.empty(3, dtype=torch.long).random_(5)
print(target)

criterion = nn.CrossEntropyLoss()
criterion.requires_grad = True  # note: this has no effect; CrossEntropyLoss has no parameters
loss0 = criterion(input,target)
loss0.register_hook(lambda x: print(" \n loss0 hook: ",x))
print("before backward loss0.grad :",loss0.grad)
print("loss0 :",loss0)
loss0.backward()
print("after backward loss0.grad :",loss0.grad)

output:

tensor([[-0.6149, -0.8179,  0.6084, -0.2837, -0.5316],
        [ 1.7246,  0.5348,  1.3646, -0.7148, -0.3421],
        [-0.3478, -0.6732, -0.7610, -1.0381, -0.5570]], requires_grad=True)
tensor([4, 1, 0])
before backward loss0.grad : None
loss0 : tensor(1.7500, grad_fn=<NllLossBackward>)

 loss0 hook:  tensor(1.)

 input hook:  tensor([[ 0.0433,  0.0354,  0.1472,  0.0603, -0.2862],
        [ 0.1504, -0.2876,  0.1050,  0.0131,  0.0190],
        [-0.2432,  0.0651,  0.0597,  0.0452,  0.0732]])
after backward loss0.grad : None

Given your specification in the comments, you want the gradient with respect to the input (the output of the model); in your code you are looking at the gradient of the loss, which does not exist. So you could do something like:

import torch
input = torch.tensor([1,0,1,0], dtype=float, requires_grad=True)
target = torch.tensor([1,2,3,4], dtype=float)
loss = (input - target).abs().mean()
loss.backward()

Here loss.grad gives you None, but input.grad returns:

tensor([ 0.0000, -0.2500, -0.2500, -0.2500], dtype=torch.float64)

Which should be the gradient you are interested in.
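To tie this back to the original question: for nn.CrossEntropyLoss with the default `reduction='mean'`, the gradient that ends up in `input.grad` after `backward()` is exactly (P(i) - T(i)) / N, i.e. the softmax probabilities minus the one-hot target, divided by the batch size. A minimal sketch verifying this (the seed and shapes are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(3, 5, requires_grad=True)   # logits for 3 samples, 5 classes
target = torch.randint(5, (3,))                  # integer class labels

loss = nn.CrossEntropyLoss()(input, target)      # default reduction='mean'
loss.backward()

# Analytic gradient w.r.t. the logits: (softmax(input) - one_hot(target)) / N
probs = F.softmax(input, dim=1)
one_hot = F.one_hot(target, num_classes=5).to(probs.dtype)
expected = (probs - one_hot) / input.shape[0]

print(torch.allclose(input.grad, expected))  # True
```

So the "P(i)-T(i)" tensor is not stored inside the loss object itself; it is materialized in `input.grad` (or passed to a hook registered on `input`, as in your code) when `backward()` runs.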
