How to update a part of torch.nn.Parameter

I want to update part of the parameters defined by torch.nn.Parameter. I have tested the following three approaches, but only one of them works.

#(1)

import torch
class NET(torch.nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.params = torch.ones(4)
        self.P = torch.nn.Parameter(torch.ones(1))
        self.params[1] = self.P
    def forward(self, x):
        y = x * self.params
        return y.sum()

net = NET()
x = torch.rand(4)
optim = torch.optim.Adam(net.parameters(), lr=0.001)
for _ in range(10):
    optim.zero_grad()
    loss = net(x)
    loss.backward()
    optim.step()

# RuntimeError: Trying to backward through the graph a second time

#(2)

import torch
class NET(torch.nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.P = torch.nn.Parameter(torch.ones(1))
    def forward(self, x):
        params = torch.ones(4)
        params[1] = self.P
        y = x * params
        return y.sum()

net = NET()
x = torch.rand(4)
optim = torch.optim.Adam(net.parameters(), lr=0.001)
for _ in range(10):
    optim.zero_grad()
    loss = net(x)
    loss.backward()
    optim.step()

# It works, but the tensor has to be created and the parameter assigned on every forward pass.

#(3)

import torch
class NET(torch.nn.Module):
    def __init__(self):
        super(NET, self).__init__()
        self.params = torch.nn.Parameter(torch.ones(4))
    def forward(self, x):
        y = x * self.params
        return y.sum()

net = NET()
net.params[1].requires_grad = False
x = torch.rand(4)
optim = torch.optim.Adam(net.parameters(), lr=0.001)
for _ in range(10):
    optim.zero_grad()
    loss = net(x)
    loss.backward()
    optim.step()

# RuntimeError: you can only change requires_grad flags of leaf variables.

I would like to know how to update part of the parameters using approaches (1) and (3).

A small note on the use of requires_grad and nn.Parameter:

  1. If you need to freeze a sub-module of your nn.Module, you would use requires_grad_. However, you cannot partially require gradients on a tensor.

  2. A nn.Parameter is a wrapper which allows a given torch.Tensor to be registered inside a nn.Module. By default, the wrapped tensor will require gradient computation (both points are illustrated in the sketch below).
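A minimal sketch illustrating both points (the Demo class and its nn.Linear sub-module are assumptions made for this example, not code from the question):

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)             # a sub-module
        self.P = nn.Parameter(torch.ones(1))  # a registered parameter tensor

demo = Demo()

# 1. freezing works on a whole module (or a whole tensor)...
demo.fc.requires_grad_(False)
# ...but not on part of a tensor:
# demo.P[0].requires_grad = False  # RuntimeError: only leaf variables can change requires_grad

# 2. nn.Parameter requires grad by default and is registered in the module
print(demo.P.requires_grad)                           # True
print([name for name, _ in demo.named_parameters()])  # ['P', 'fc.weight', 'fc.bias']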

Your parameter tensor must therefore be defined as:

self.params = nn.Parameter(torch.ones(4))

And not as:

self.params = torch.ones(4)

Ultimately, you should check the content of your registered parameters with nn.Module#parameters before loading them into an optimizer.
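For instance, with the NET class from code #1 above (a quick check, not in the original post), only P shows up; the 4-element params tensor is a plain attribute and will never be seen by the optimizer:

>>> net = NET()  # the NET from code #1
>>> [(name, p.shape) for name, p in net.named_parameters()]
[('P', torch.Size([1]))]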

Your first code #1 crashes because you are backpropagating through the same graph a second time without explicitly setting retain_graph=True: the in-place assignment self.params[1] = self.P in __init__ builds a piece of the graph that every forward pass reuses, and it is freed after the first backward. The following runs without the error:

for _ in range(10):
    optim.zero_grad()
    x = torch.rand(4)  # new x
    loss = net(x)
    loss.backward(retain_graph=True)  # keep the graph built in __init__ alive
    optim.step()

Note, however, that params[1] only received a copy of P's value at construction time, so the optimizer's updates to P never reach the forward pass; re-assigning inside forward, as in #2, avoids this.

Your second code #2 works because the parameter P is assigned into a freshly created tensor on every forward pass, so a new graph is built at each iteration. A minimal check that the gradient is indeed computed on P is as follows:

# reassign parameter tensor to bigger tensor
>>> P = torch.ones(1, requires_grad=True)
>>> params = torch.ones(4)
>>> params[1] = P

# inference and backpropagation
>>> (torch.rand(4)*params).sum().backward()
>>> P.grad
tensor([0.46701658])

Your third code #3 is invalid because you are trying to change the requires_grad flag on part of a tensor, which is not possible:

net.params[1].requires_grad = False # invalid: indexing returns a non-leaf view

An alternative is to mask the gradient after backpropagation, before the optimizer updates the parameters:

import torch
from torch import nn

class NET(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.Parameter(torch.ones(4))

    def forward(self, x):
        y = x * self.params
        return y.sum()


net = NET()
net(torch.rand(4)).backward()
net.params.grad[1] = 0  # zero out the gradient of the frozen entry before optim.step()
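If you would rather not repeat that masking line after every backward call, one possible variation (my own sketch, not part of the original answer) is to register a gradient hook on the parameter so the entry is zeroed automatically whenever a gradient is computed:

# reuse `net` from the snippet above
net.params.register_hook(lambda grad: grad * torch.tensor([1., 0., 1., 1.]))  # zero the grad at index 1

optim = torch.optim.Adam(net.parameters(), lr=0.001)
for _ in range(10):
    optim.zero_grad()
    net(torch.rand(4)).backward()
    optim.step()  # params[1] always receives a zero gradient, so it is never updated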
