I have a neural network whose output I want to transform before the loss is computed and backpropagation happens.
Here is my general code:
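with torch.set_grad_enabled(training):
    outputs = net(x_batch[:, 0], x_batch[:, 1])  # the prediction of the NN
    # My issue is here:
    outputs = transform_torch(outputs)
    loss = my_loss(outputs, y_batch)
    if training:
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        loss.backward()
        optimizer.step()
        scheduler.step()  # PyTorch expects the scheduler step after optimizer.step()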
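For reference, here is a minimal self-contained sketch of one training step in the order PyTorch documents (zero_grad, backward, optimizer.step, then scheduler.step). The two-input network, the MSE loss, the SGD optimizer and the StepLR scheduler are all hypothetical stand-ins for the question's net, my_loss, optimizer and scheduler; the transform defaults to the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TwoInputNet(nn.Module):
    # Hypothetical stand-in for `net`, called as net(x_batch[:, 0], x_batch[:, 1]).
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 3)

    def forward(self, a, b):
        return self.fc(torch.stack([a, b], dim=1))


net = TwoInputNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
my_loss = nn.MSELoss()  # hypothetical choice of loss


def run_batch(x_batch, y_batch, training, transform=lambda t: t):
    with torch.set_grad_enabled(training):
        outputs = net(x_batch[:, 0], x_batch[:, 1])
        outputs = transform(outputs)  # must stay differentiable (no detach)
        loss = my_loss(outputs, y_batch)
        if training:
            optimizer.zero_grad()  # clear gradients from the previous batch
            loss.backward()        # fill .grad on the parameters
            optimizer.step()       # update the parameters
            scheduler.step()       # adjust the learning rate afterwards
    return loss.item()


x_batch = torch.randn(16, 2)
y_batch = torch.randn(16, 3)
before = net.fc.weight.detach().clone()
run_batch(x_batch, y_batch, training=True)
print(torch.equal(before, net.fc.weight))  # False: the parameters were updated
```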
Following the advice in "How to transform output of neural network and still train?", I pass my output through this transformation function:
def transform_torch(predictions):
    new_tensor = []
    for i in range(int(len(predictions))):
        arr = predictions[i]
        a = arr.clone().detach()
        # My transformation, which results in a positive first element, and the
        # other elements represent decrements of the first positive element.
        b = torch.negative(a)
        b[0] = abs(b[0])
        new_tensor.append(torch.cumsum(b, dim=0))
        # new_tensor[i].requires_grad = True
    new_tensor = torch.stack(new_tensor, 0)
    return new_tensor
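Written out per row, the loop computes the following, where p is one row of predictions with n elements and y is the corresponding row of the result (p, y and n are notation introduced here, with 0-based indices to match the code):

$$
y_0 = |p_0|, \qquad y_k = |p_0| - \sum_{j=1}^{k} p_j \quad (1 \le k < n)
$$

For example, p = (-2, 0.5, 1) gives y = (2, 1.5, 0.5).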
Note: in addition to clone().detach(), I also tried the methods described in "Pytorch preferred way to copy a tensor", with similar results.
My problem is that no training actually happens when the output is transformed this way.
If I instead modify the tensor in place (e.g. modify arr directly), PyTorch complains that I can't modify a tensor in place while it has a gradient attached.
Any suggestions?
Calling detach() on your predictions stops gradient propagation to your model. Nothing you do after that can change your parameters.
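A small sketch of that effect, using a hypothetical nn.Linear in place of the question's net: the detached copy no longer requires grad, and switching requires_grad back on (as the commented-out line in the question does) only turns the copy into a new leaf tensor, so backward() stops there and the model's parameters never receive a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(3, 4)  # hypothetical stand-in for the question's `net`
out = net(torch.randn(2, 3))

copy = out.clone().detach()
print(out.requires_grad, copy.requires_grad)  # True False

# Re-enabling requires_grad makes `copy` a fresh leaf, cut off from `net`.
copy.requires_grad_(True)
copy.sum().backward()
print(copy.grad is not None)  # True: the copy itself receives a gradient
print(net.weight.grad)        # None: nothing reaches the model's parameters
```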
How about modifying your code to avoid this:
def transform_torch(predictions):
    b = torch.cat([predictions[:, :1, ...].abs(), -predictions[:, 1:, ...]], dim=1)
    new_tensor = torch.cumsum(b, dim=1)
    return new_tensor
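A quick check of this vectorized version, assuming predictions has shape (batch, n): on a known row it reproduces the loop's values, and on a hypothetical nn.Linear model with an MSE loss the gradient reaches the parameters, because the function uses neither detach() nor in-place writes.

```python
import torch
import torch.nn as nn


def transform_torch(predictions):
    # First column -> |p0|, remaining columns negated, then a running sum per row.
    b = torch.cat([predictions[:, :1].abs(), -predictions[:, 1:]], dim=1)
    return torch.cumsum(b, dim=1)


# Known input: |-2| = 2, then 2 - 0.5 = 1.5, then 1.5 - 1.0 = 0.5.
p = torch.tensor([[-2.0, 0.5, 1.0]])
print(transform_torch(p))  # tensor([[2.0000, 1.5000, 0.5000]])

# Gradients flow back through the transformation to the model.
torch.manual_seed(0)
net = nn.Linear(3, 4)  # hypothetical stand-in for the question's `net`
x = torch.randn(8, 3)
y = torch.randn(8, 4)
loss = nn.functional.mse_loss(transform_torch(net(x)), y)
loss.backward()
print(net.weight.grad is not None)  # True
```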
How about extracting the gradient from the tensor with something like grad = output.grad and, after the transformation, assigning the same gradient to the new tensor?
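One detail worth keeping in mind with this approach: a tensor's .grad is only filled in by backward(), and for a non-leaf tensor such as a network's output only if retain_grad() was called on it. The sketch below, with a hypothetical nn.Linear standing in for net, shows that output.grad is still None at the point where the transformation would run, before any loss exists.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(3, 2)  # hypothetical stand-in for the question's `net`
output = net(torch.randn(4, 3))
output.retain_grad()  # non-leaf tensors only keep .grad when asked to

print(output.grad)  # None: backward() has not run yet
output.sum().backward()
print(output.grad)  # a (4, 2) tensor of ones, available only after backward()
```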