[英]Efficient batch derivative operations in PyTorch
我正在使用 Pytorch 來實現具有(比如說)5 個輸入和 2 個輸出的神經網絡
class myNetwork(nn.Module):
def __init__(self):
super(myNetwork,self).__init__()
self.layer1 = nn.Linear(5,32)
self.layer2 = nn.Linear(32,2)
def forward(self,x):
x = torch.relu(self.layer1(x))
x = self.layer2(x)
return x
顯然,我可以輸入一個 (N x 5) 張量並得到 (N x 2) 結果,
net = myNetwork()
nbatch = 100
inp = torch.rand([nbatch,5])
inp.requires_grad = True
out = net(inp)
我現在想為批處理中的每個示例計算 NN output 關於輸入向量的一個元素(假設是第 5 個元素)的導數。 我知道我可以使用torch.autograd.grad
計算 output 的一個元素相對於所有輸入的導數,我可以按如下方式使用它:
deriv = torch.zeros([nbatch,2])
for i in range(nbatch):
for j in range(2):
deriv[i,j] = torch.autograd.grad(out[i,j],inp,retain_graph=True)[0][i,4]
然而,這似乎非常低效:它計算out[i,j]
關於批次中每個元素的梯度,然后丟棄除一個之外的所有元素。 有一個更好的方法嗎?
憑借反向傳播,如果您只計算梯度 w.r.ta 單個輸入,則計算節省不一定很多,您只會在第一層節省一些,之后的所有層都需要反向傳播.
因此,這可能不是最佳方式,但實際上並不會產生太多開銷,尤其是在您的網絡有很多層的情況下。
順便說一句,是否有理由需要循環nbatch
? 如果您想要批次 w.r.ta 參數的每個元素的梯度,我可以理解,因為 pytorch 會將它們放在一起,但您似乎只對輸入感興趣......
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.