簡體   English   中英

y = x / sum(x, dim=0) 的反向傳播,其中張量 x 的大小為 (H,W)

[英]Back-Propagation of y = x / sum(x, dim=0) where size of tensor x is (H,W)

Q1。

我正在嘗試使用 pytorch 制作我的自定義 autograd function。

但是我在用 y = x / sum(x, dim=0) 進行分析反向傳播時遇到了問題

其中張量 x 的大小是(高度,寬度)(x 是二維的)。

這是我的代碼

class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
  ctx.save_for_backward(input)
  input = input / torch.sum(input, dim=0)

  return input

@staticmethod
def backward(ctx, grad_output):
  input = ctx.saved_tensors[0]
  H, W = input.size()
  sum = torch.sum(input, dim=0)
  grad_input = grad_output * (1/sum - input*1/sum**2)

  return grad_input

我使用 (torch.autograd import) gradcheck 來比較雅可比矩陣,

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)

結果是

在此處輸入圖像描述

請有人幫我獲得正確的反向傳播結果

謝謝!


Q2。

感謝您的回答!

由於您的幫助,我可以在 (H,W) 張量的情況下實現反向傳播。

然而,當我在 (N,H,W) 張量的情況下實現反向傳播時,我遇到了一個問題。 我認為問題在於初始化新張量。

這是我的新代碼

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    
    N = input.size(0)
    for n in range(N):
      input[n] /= torch.sum(input[n], dim=0)

    return input

  @staticmethod
  def backward(ctx, grad_output):
    input = ctx.saved_tensors[0]
    N, H, W = input.size()
    I = torch.eye(H).unsqueeze(-1)
    sum = input.sum(1)

    grad_input = torch.zeros((N,H,W), dtype = torch.double, requires_grad=True)
    for n in range(N):
      grad_input[n] = ((sum[n] * I - input[n]) * grad_output[n] / sum[n]**2).sum(1)

    return grad_input

畢業檢查代碼是

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.rand(2,2,2,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)
print(test)

結果是在此處輸入圖像描述

我不知道為什么會出現錯誤...

您的幫助將對我實現自己的卷積網絡非常有幫助。

謝謝。 祝你今天過得愉快。

讓我們看一個帶有單列的示例,例如: [[x1], [x2], [x3]]

sumx1 + x2 + x3 ,然后標准化x將給出y = [[y1], [y2], [y3]] = [[x1/sum], [x2/sum], [x3/sum]] 您正在尋找dL/dx1dL/x2dL/x3 - 我們將它們寫為: dx1dx3 dx2 所有dL/dyi都一樣。

所以dx1等於dL/dy1*dy1/dx1 + dL/dy2*dy2/dx1 + dL/dy3*dy3/dx1 這是因為x1對相應列上的所有輸出元素都有貢獻: y1y2y3

我們有:

  • dy1/dx1 = d(x1/sum)/dx1 = (sum - x1)/sum²

  • dy2/dx1 = d(x2/sum)/dx1 = -x2/sum²

  • 同樣, dy3/dx1 = d(x3/sum)/dx1 = -x3/sum²

因此dx1 = (sum - x1)/sum²*dy1 - x2/sum²*dy2 - x3/sum²*dy3 dx3 dx2 因此,雅可比是[dxi]_i = (sum - xi)/sum²[dxi]_j = -xj/sum² (對於所有不同於ij )。

在您的實現中,您似乎缺少所有非對角線組件。

保持相同的單列示例,使用x1=2x2=3x3=5

>>> x = torch.tensor([[2.], [3.], [5.]])

>>> sum = input.sum(0)
tensor([10])

雅可比將是:

>>> J = (sum*torch.eye(input.size(0)) - input)/sum**2
tensor([[ 0.0800, -0.0200, -0.0200],
        [-0.0300,  0.0700, -0.0300],
        [-0.0500, -0.0500,  0.0500]])

對於具有多列的實現,它有點棘手,更具體地說是對角矩陣的形狀。 軸保持在最后更容易,因此我們不必為廣播而煩惱:

>>> x = torch.tensor([[2., 1], [3., 3], [5., 5]])
>>> sum = x.sum(0)
tensor([10.,  9.])

>>> diag = sum*torch.eye(3).unsqueeze(-1).repeat(1, 1, len(sum))
tensor([[[10.,  9.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [10.,  9.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [10.,  9.]]])

上面的diag具有(3, 3, 2)的形狀,其中兩位於最后一個軸上。 注意我們不需要廣播sum

不會做的是: torch.eye(3).unsqueeze(0).repeat(len(sum), 1, 1) 由於使用這種形狀 - (2, 3, 3) - 你將不得不使用sum[:, None, None] ,並且需要進一步廣播......

雅可比行列式很簡單:

>>> J = (diag - x)/sum**2
tensor([[[ 0.0800,  0.0988],
         [-0.0300, -0.0370],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [ 0.0700,  0.0741],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [-0.0300, -0.0370],
         [ 0.0500,  0.0494]]])

您可以使用任意dy向量通過操作反向傳播來檢查結果(雖然不是使用torch.ones ,但由於J ,您將得到0 s。),反向傳播后, x.grad應該等於torch.einsum('abc,bc->ac', J, dy)

您的雅可比行列式不准確:它是 4d 張量,您只計算了它的 2D 切片。

您忽略了雅可比行列式的第二行:

在此處輸入圖像描述

回答 Q2。

我自己為許多批處理案例實施了反向傳播。 我使用了 unsqueeze function 並且它起作用了。

輸入大小:(N,H,W)(N是批量大小)

forward:
  out = input / torch.sum(input, dim=1).unsqueeze(1)

backward:
  diag = torch.eye(input.size(1),  dtype=torch.double, requires_grad=True).unsqueeze(-1)
  sum = input.sum(1)
  grad_input = ((sum.unsqueeze(1).unsqueeze(1) * diag - input.unsqueeze(1)) * grad_out.unsqueeze(1) / (sum**2).unsqueeze(1).unsqueeze(1)).sum(2)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM