
Efficient batch derivative operations in PyTorch

I am using PyTorch to implement a neural network that has (say) 5 inputs and 2 outputs:

import torch
import torch.nn as nn

class myNetwork(nn.Module):
   def __init__(self):
      super(myNetwork,self).__init__()
      self.layer1 = nn.Linear(5,32)   # 5 inputs -> 32 hidden units
      self.layer2 = nn.Linear(32,2)   # 32 hidden units -> 2 outputs
   def forward(self,x):
      x = torch.relu(self.layer1(x))
      x = self.layer2(x)
      return x

Obviously, I can feed this an (N x 5) tensor and get an (N x 2) result:

net = myNetwork()
nbatch = 100
inp = torch.rand([nbatch,5])
inp.requires_grad = True   # track gradients with respect to the input
out = net(inp)             # shape (nbatch, 2)

I would now like to compute the derivatives of the NN output with respect to one element of the input vector (say the 5th element), for each example in the batch. I know I can calculate the derivatives of one element of the output with respect to all inputs using torch.autograd.grad, and I could use this as follows:

deriv = torch.zeros([nbatch,2])
for i in range(nbatch):
   for j in range(2):
      # gradient of out[i,j] w.r.t. every input, keeping only d/d inp[i,4]
      deriv[i,j] = torch.autograd.grad(out[i,j],inp,retain_graph=True)[0][i,4]

However, this seems very inefficient: it calculates the gradient of out[i,j] with respect to every single element in the batch, and then discards all except one. Is there a better way to do this?

By virtue of backpropagation, even if you only computed the gradient w.r.t. a single input, the computational savings wouldn't necessarily amount to much: you would only save some work in the first layer, since all later layers have to be backpropagated through either way.

So this may not be the optimal way, but it doesn't actually create much overhead, especially if your network has many layers.

By the way, is there a reason that you need to loop over nbatch? If you wanted the gradient of each element of a batch w.r.t. a parameter, I could understand that, because PyTorch will lump them together, but you seem to be solely interested in the input... A sketch of how the batch loop can be dropped follows.
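As a minimal sketch (reusing net, inp, out, and nbatch from the question), the loop over the batch can be removed by exploiting the fact that the batch examples are independent: the gradient of out[:, j].sum() with respect to inp contains d out[i, j] / d inp[i, :] in row i, because all cross-example terms are zero.

# Sketch, assuming net, inp, out and nbatch as defined in the question above.
# Since example i's output depends only on inp[i, :], backpropagating
# out[:, j].sum() yields the per-example gradients in one call per output.
deriv = torch.zeros([nbatch, 2])
for j in range(2):
   grads = torch.autograd.grad(out[:, j].sum(), inp, retain_graph=True)[0]
   deriv[:, j] = grads[:, 4]   # keep only the derivative w.r.t. the 5th input

This reduces the number of backward passes from 2 * nbatch to 2 while producing the same deriv tensor as the double loop.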
