簡體   English   中英

PyTorch中的高效批量導數操作

[英]Efficient batch derivative operations in PyTorch

我正在使用 Pytorch 來實現具有(比如說)5 個輸入和 2 個輸出的神經網絡

class myNetwork(nn.Module):
   def __init__(self):
      super(myNetwork,self).__init__()
      self.layer1 = nn.Linear(5,32)
      self.layer2 = nn.Linear(32,2)
   def forward(self,x):
      x = torch.relu(self.layer1(x))
      x = self.layer2(x)
      return x

顯然,我可以輸入一個 (N x 5) 張量並得到 (N x 2) 結果,

net = myNetwork()
nbatch = 100
inp = torch.rand([nbatch,5])
inp.requires_grad = True
out = net(inp)

我現在想為批處理中的每個示例計算 NN output 關於輸入向量的一個元素(假設是第 5 個元素)的導數。 我知道我可以使用torch.autograd.grad計算 output 的一個元素相對於所有輸入的導數,我可以按如下方式使用它:

deriv = torch.zeros([nbatch,2])
for i in range(nbatch):
   for j in range(2):
      deriv[i,j] = torch.autograd.grad(out[i,j],inp,retain_graph=True)[0][i,4]

然而,這似乎非常低效:它計算out[i,j]關於批次中每個元素的梯度,然后丟棄除一個之外的所有元素。 有一個更好的方法嗎?

憑借反向傳播,如果您只計算梯度 w.r.ta 單個輸入,則計算節省不一定很多,您只會在第一層節省一些,之后的所有層都需要反向傳播.

因此,這可能不是最佳方式,但實際上並不會產生太多開銷,尤其是在您的網絡有很多層的情況下。

順便說一句,是否有理由需要循環nbatch 如果您想要批次 w.r.ta 參數的每個元素的梯度,我可以理解,因為 pytorch 會將它們放在一起,但您似乎只對輸入感興趣......

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM