
Efficient batch derivative operations in PyTorch

I am using PyTorch to implement a neural network that has (say) 5 inputs and 2 outputs:

import torch
import torch.nn as nn

class myNetwork(nn.Module):
    def __init__(self):
        super(myNetwork, self).__init__()
        self.layer1 = nn.Linear(5, 32)
        self.layer2 = nn.Linear(32, 2)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

Obviously, I can feed this an (N x 5) Tensor and get an (N x 2) result,

net = myNetwork()
nbatch = 100
inp = torch.rand([nbatch, 5])
inp.requires_grad = True   # track gradients w.r.t. the inputs
out = net(inp)

I would now like to compute the derivatives of the NN output with respect to one element of the input vector (say the 5th element), for each example in the batch. I know I can calculate the derivatives of one element of the output with respect to all inputs using torch.autograd.grad, and I could use it as follows:

deriv = torch.zeros([nbatch, 2])
for i in range(nbatch):
    for j in range(2):
        # one backward pass per (example, output) pair, keeping only d out[i, j] / d inp[i, 4]
        deriv[i, j] = torch.autograd.grad(out[i, j], inp, retain_graph=True)[0][i, 4]

However, this seems very inefficient: it calculates the gradient of out[i,j] with respect to the entire input tensor (every example in the batch), and then discards all but one entry. Is there a better way to do this?

By virtue of backpropagation, if you only computed the gradient w.r.t. a single input, the computational savings wouldn't necessarily amount to much: you would only save some work in the first layer, since all subsequent layers have to be backpropagated through either way.

So this may not be the optimal way, but it doesn't actually create much overhead, especially if your network has many layers.
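That said, for a plain feed-forward network like the one above, each row of out depends only on the corresponding row of inp, so you can cut the number of backward passes from 2*nbatch down to 2 by summing an output column over the batch before calling autograd. A minimal sketch (reusing net, inp, out and nbatch from the question; this assumes the batch rows really are independent):

deriv = torch.zeros([nbatch, 2])
for j in range(2):
    # one backward pass per output dimension instead of one per (example, output) pair
    g = torch.autograd.grad(out[:, j].sum(), inp, retain_graph=True)[0]   # shape (nbatch, 5)
    deriv[:, j] = g[:, 4]   # keep only the derivative w.r.t. the 5th input

This gives the same result as the double loop, because the cross-example entries of the gradient are zero anyway.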

By the way, is there a reason that you need to loop over nbatch? If you wanted the gradient of each element of a batch w.r.t. a parameter, I could understand that, because PyTorch will lump them together, but you seem to be solely interested in the input...
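If you want to avoid the Python loop altogether, newer PyTorch versions ship torch.func (formerly functorch), which can build per-example Jacobians directly. A sketch, assuming torch.func is available in your installation:

from torch.func import jacrev, vmap

# jacrev(net) gives the (2, 5) Jacobian of the output w.r.t. a single example;
# vmap maps it over the batch dimension, producing a (nbatch, 2, 5) tensor.
jac = vmap(jacrev(net))(inp)
deriv = jac[:, :, 4]   # derivatives w.r.t. the 5th input, shape (nbatch, 2)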
