
How to partition a neural network into sub-networks in PyTorch?

I'd like to partition a neural network into two sub-networks using PyTorch. To make things concrete, consider this image:

[Image: a 3x4x1 neural network divided into sub-network 1 and sub-network 2]

In (1), I have a 3x4x1 neural network. What I want is this: during epoch 1, for example, I'd like to update only the weights in sub-network 1, i.e., the weights that appear in sub-network 2 must stay frozen. Then, in epoch 2, I'd like to train the weights in sub-network 2 while the rest stay frozen.

How can I do that?

This is easy if each sub-network is made up of whole layers, that is, if you never need to freeze only part of a layer. Freezing is all or nothing per layer.

For your example, that means splitting the 4-node hidden layer into two separate 2-node layers. Each of them belongs to exactly one sub-network, which brings us back to all or nothing.
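Splitting the layer does not change what the network computes: a single `nn.Linear(3, 4)` produces the same outputs as two `nn.Linear(3, 2)` layers whose results are concatenated, as long as the weight rows are divided between them. A minimal sketch that checks this numerically (the names `full`, `half1` and `half2` are only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One 4-node hidden layer, as in the original 3x4x1 network.
full = nn.Linear(3, 4)

# The same layer split into two independent 2-node layers.
half1 = nn.Linear(3, 2)
half2 = nn.Linear(3, 2)

# Give half1 the first two output rows of `full` and half2 the last two.
with torch.no_grad():
    half1.weight.copy_(full.weight[:2])
    half1.bias.copy_(full.bias[:2])
    half2.weight.copy_(full.weight[2:])
    half2.bias.copy_(full.bias[2:])

x = torch.randn(5, 3)  # a batch of 5 three-feature inputs
split_out = torch.cat([half1(x), half2(x)], dim=1)  # shape (5, 4)
print(torch.allclose(full(x), split_out))  # True
```

So the split only changes how the parameters are grouped, which is exactly what lets each half be frozen on its own.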

With that done, you can freeze or unfreeze individual layers via requires_grad. Setting it to False on a layer's parameters stops autograd from computing gradients for them, which disables training and freezes those weights. To do this for an entire model, sub-model, or Module, loop over its parameters() and set the flag on each one.
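A small sketch of the effect, assuming a toy `nn.Sequential` model and random data (both hypothetical): the first layer is frozen with `requires_grad = False`, and after one optimizer step its weights are untouched and it has no gradient, while the last layer is still updated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def set_grad(module, grad):
    # Toggle requires_grad on every parameter owned by the module (recursively).
    for param in module.parameters():
        param.requires_grad = grad

model = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
set_grad(model[0], False)  # freeze the first layer only

frozen_before = model[0].weight.detach().clone()
trainable_before = model[2].bias.detach().clone()

opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()

print(model[0].weight.grad)                          # None: no gradient was computed
print(torch.equal(model[0].weight, frozen_before))   # True: frozen layer unchanged
print(torch.equal(model[2].bias, trainable_before))  # False: last layer was updated
```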

For your example, with 3 inputs, 1 output, and the hidden layer now split into two 2-node halves, it might look something like this:

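import torch
import torch.nn as nn
import torch.nn.functional as F

def set_grad(model, grad):
    # Enable (True) or disable (False) gradients for every parameter of the module.
    for param in model.parameters():
        param.requires_grad = grad

class HalfFrozenModel(nn.Module):
    def __init__(self):
        super().__init__()
        # The 4-node hidden layer is split into two independent 2-node layers.
        self.hid1 = nn.Linear(3, 2)  # first half of the hidden layer
        self.hid2 = nn.Linear(3, 2)  # second half of the hidden layer
        # The output layer sees all 4 hidden units. set_freeze never freezes it,
        # so it is trained in every phase.
        self.out = nn.Linear(4, 1)

    def set_freeze(self, hid1=False, hid2=False):
        # True freezes that half; a half that is not passed is unfrozen.
        set_grad(self.hid1, not hid1)
        set_grad(self.hid2, not hid2)

    def forward(self, inp):
        hid1 = self.hid1(inp)
        hid2 = self.hid2(inp)
        # Re-join the two halves along the feature dimension: (batch, 4).
        hidden = torch.cat([hid1, hid2], 1)
        return self.out(F.relu(hidden))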

Then you can train one half or the other like so:

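model = HalfFrozenModel()
model.set_freeze(hid1=True)  # freeze hid1: hid2 and out are trainable
# Do some training.
model.set_freeze(hid2=True)  # freeze hid2 (hid1 is unfrozen again): hid1 and out are trainable
# Do some more training.
# ...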
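One way to fill in the "Do some training" steps is sketched below, alternating the frozen half every epoch as in the question. It assumes a single Adam optimizer over all parameters, toy random data, and a hypothetical `train_epoch` helper; a compact copy of the model is repeated so the snippet runs on its own. It also checks that the frozen half really does not move.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def set_grad(module, grad):
    for param in module.parameters():
        param.requires_grad = grad

class HalfFrozenModel(nn.Module):
    # Compact copy of the model above so this snippet is self-contained.
    def __init__(self):
        super().__init__()
        self.hid1 = nn.Linear(3, 2)
        self.hid2 = nn.Linear(3, 2)
        self.out = nn.Linear(4, 1)

    def set_freeze(self, hid1=False, hid2=False):
        set_grad(self.hid1, not hid1)
        set_grad(self.hid2, not hid2)

    def forward(self, inp):
        hidden = torch.cat([self.hid1(inp), self.hid2(inp)], 1)
        return self.out(F.relu(hidden))

# Toy regression data (hypothetical, for illustration only).
x, y = torch.randn(64, 3), torch.randn(64, 1)

model = HalfFrozenModel()
# One optimizer over all parameters: frozen ones are skipped because their
# .grad stays None after zero_grad(set_to_none=True).
opt = torch.optim.Adam(model.parameters(), lr=0.01)

def train_epoch(steps=20):
    # Hypothetical helper: a few full-batch optimization steps.
    for _ in range(steps):
        opt.zero_grad(set_to_none=True)
        loss = F.mse_loss(model(x), y)
        loss.backward()
        opt.step()

snapshots = {}
for epoch in range(4):
    # Even epochs: freeze hid1 and train hid2; odd epochs: the reverse.
    if epoch % 2 == 0:
        model.set_freeze(hid1=True)
        frozen = model.hid1
    else:
        model.set_freeze(hid2=True)
        frozen = model.hid2
    before = frozen.weight.detach().clone()
    train_epoch()
    # Record whether the frozen half came through the epoch unchanged.
    snapshots[epoch] = torch.equal(frozen.weight, before)

print(snapshots)  # {0: True, 1: True, 2: True, 3: True}
```

Passing set_to_none=True matters with optimizers that keep per-parameter state: if a frozen parameter were left with a zero gradient instead of None, Adam's stored moment estimates (or momentum and weight decay in SGD) could still move it. Recent PyTorch versions already default to set_to_none=True; building a fresh optimizer over only the trainable parameters in each phase is another option.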

If you use fastai, it has a concept of layer groups that serves this same purpose. The fastai documentation explains in some detail how they work.
