[英]Derivative of Scalar Expansion in PyTorch
我正在努力在 Rust 中實現一個非常簡單的自動差異庫,以擴展我對它是如何完成的知識。 我幾乎一切正常,但在實現負對數似然時,我意識到我對如何處理以下場景的導數有些困惑(我在下面的 PyTorch 中編寫了它)。
x = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grads=True)
y = x - torch.sum(x)
我環顧四周,進行了實驗,但對這里實際發生的事情仍然有些困惑。 我知道上面等式關於 x 的導數是 [-2, -2, -2],但是有很多方法可以到達那里,當我將等式擴展為以下內容時:
x = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grads=True)
y = torch.exp(x - torch.sum(x))
我完全迷路了,不知道它是如何得出 x 的梯度的。
我假設上面的方程式被重寫為這樣的東西:
y = (x - [torch.sum(x), torch.sum(x), torch.sum(x)])
但我不確定,而且我真的很難找到有關標量擴展為向量或實際發生的任何事情的信息。 如果有人能指出我正確的方向,那就太棒了!
如果有幫助,我可以包括上述方程的梯度 pytorch 計算。
首先,有幾件事,參數是requires_grad
而不是require_grads
。 其次,您只能為浮點數或復雜數據類型要求梯度。
現在,標量加法/乘法(注意減法/除法可以看作是加 -ve 數/乘以分數)簡單地將標量與張量的所有元素相加/相乘。 因此,
x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x - 1
評估為:
y = tensor([-1., 0., 1.], grad_fn=<SubBackward0>)
因此,在您的情況下, torch.sum(x)
基本上是從張量x
的所有元素中減去的標量。
如果您對漸變部分更感興趣,請查看 pytorch 關於 autograd 的文檔 [ ref ]。 它聲明如下:
使用鏈式法則對圖進行微分。 如果任何張量是非標量的(即它們的數據有多個元素)並且需要梯度,那么將計算雅可比向量積,在這種情況下 function 還需要指定 grad_tensors。 它應該是一個匹配長度的序列,包含雅可比向量積中的“向量”,通常是微分 function w.r.t 的梯度。 相應的張量(
None
是所有不需要梯度張量的張量的可接受值)。
如果不進行任何修改,您的代碼將無法與 PyTorch 一起使用,因為它沒有指定 w.r.t 到y
的梯度是什么。 你需要他們backward
打電話。 從你的 all -2 結果來看,我認為梯度必須是所有的。
“標量擴展”稱為廣播。 如您所知,只要兩個張量操作數的形狀不匹配,就會執行廣播。 我的猜測是,它的實現方式與 PyTorch 中的任何其他操作相同,知道如何在給定梯度 w.r.t 其輸出的情況下計算其輸入的梯度 w.r.t。 下面給出了一個簡單的示例,其中 (a) 適用於您給定的測試用例,並且 (b) 允許我們仍然使用 PyTorch 的 autograd 來自動計算梯度(另請參閱PyTorch 關於擴展 autograd 的文檔):
class Broadcast(torch.autograd.Function):
def forward(ctx, x: torch.Tensor, length: int) -> torch.Tensor:
assert x.ndim == 0, "input must be a scalar tensor"
assert length > 0, "length must be greater than zero"
return x.new_full((length,), x.item())
def backward(ctx, grad: torch.Tensor) -> Tuple[torch.Tensor, None]:
return grad.sum(), None
現在,通過設置broadcast = Broadcast.apply
我們可以自己調用廣播而不是讓 PyTorch 自動執行。
x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x - broadcast(torch.sum(x), x.size(0))
y.backward(torch.ones_like(y))
assert torch.allclose(torch.tensor(-2.), x.grad)
請注意,我不知道 PyTorch 實際上是如何實現的。 上面的實現只是為了說明如何實現廣播操作以使自動微分起作用,希望能回答您的問題。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.