繁体   English   中英

在 PyTorch Function 中使用 PyTorch 模块的正确方法是什么?

[英]What is the correct way to use a PyTorch Module inside a PyTorch Function?

我们有一个自定义的torch.autograd.Function z(x, t)以一种不适合直接自动微分的方式计算 output y ,并计算了关于其输入xt的操作的雅可比行列式,因此我们可以实施backward的方法。

然而,该操作涉及对 neural.network 进行多次内部调用,我们现在将其实现为一堆torch.nn.Linear对象,包裹在net中,一个torch.nn.Module 在数学上,这些由t参数化。

有什么方法可以让net本身成为zforward方法的输入吗? 然后,我们将从backward返回上游梯度Dy和参数 Jacobia dydt_i的产品列表,每个参数ti都是net的子级(除了Dy*dydx ,尽管 x 是数据并且不需要梯度累积)。

或者我们是否真的需要取t (实际上是单个t_i的列表),并在z.forward中内部重建net中所有Linear层的动作?

我想你可以创建一个继承torch.autograd.Function的自定义仿函数,并使forwardbackward方法成为非静态的(即删除本例中的@staticmethod ,这样net就可以成为你的仿函数的一个属性。看起来像

class MyFunctor(torch.nn.autograd.Function):
    def __init(net):
         self.net = net
    
     def forward(ctx, x, t):
         #store x and t in ctx in the way you find useful
         # not sure how t is involved here
         return self.net(x) 

     def backward(ctx, grad):
         # do your backward stuff

net = nn.Sequential(nn.Linear(...), ...)
z = MyFunctor(net)
y = z(x, t)

这将产生一条警告,表明您正在使用不推荐使用的传统方式创建 autograd 函数(因为非静态方法),并且在反向传播后将net中的梯度归零时需要格外小心。 所以不是很方便,但我不知道有任何更好的方法来拥有有状态的 autograd function。

我正在做类似的事情,其中 PyTorch 功能的 static 限制很麻烦。 与 trialNerror 的答案类似,我保留了 PyTorch function 方法 static 并传入函数供它们使用,这解决了使仿函数非静态的问题:

class NonStaticBackward(Function):
    @staticmethod
    def forward(ctx, backward_fn, input):
        ctx.backward_fn = backward_fn
        # ... do other stuff
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Call into our non-static backward function

        # Since we passed in the backward function as input, 
        # PyTorch expects a placeholder grad for it. 
        return None, ctx.backward_fn(ctx, grad_output)

每次向后传递 function 都很烦人,所以我通常将其包装起来:

def my_non_static_backward(ctx, grad_output):
    print("Hello from backward!")
    return grad_output

my_fn = lambda x: NonStaticBackward.apply(my_non_static_backward, x)

y = my_fn(Tensor([1, 2, 3]))

这样,您可以将 grad function 写在它可以访问所需内容的地方:无需通过net

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM