
How to add weight normalisation to PyTorch's pretrained VGG16?

I want to add weight normalization to PyTorch's pre-trained VGG-16. One possible solution I can think of is the following:

import torch.nn as nn
from torch.nn.utils import weight_norm as wn
import torchvision.models as models

class ResnetEncoder(nn.Module):
    def __init__(self):
        super(ResnetEncoder, self).__init__()
        ...
        self.encoder = models.vgg16(pretrained=True).features
        ...
    def forward(self, input_image):
        self.features = []
        x = (input_image - self.mean) / self.std
        
        self.features.append(self.encoder(x))
        ...

        return self.features

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = ResnetEncoder() # this is basically VGG16
        self.decoder = DepthDecoder(self.encoder.num_ch_enc)
        for k,m in self.encoder.encoder._modules.items():
            if isinstance(m,nn.Conv2d):
                m = wn(m)

    def forward(self,x):
        return self.decoder(self.encoder(x))

vgg_backbone_model = Net()
vgg_backbone_model.train()
...

But I do not know if this is the correct way to add weight normalization to pre-trained VGG16.

You should iterate with nn.Module.named_modules (or nn.Module.modules) instead of accessing the private _modules attribute.

Also, doing m = wn(m) only rebinds the loop variable m; it does not explicitly replace the layer that the model stores. Instead, override the layer on the nn.Module that owns it, for example with setattr. Note that named_modules yields dotted names such as encoder.encoder.0, so you first need to resolve the parent module:

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        # named_modules yields dotted names (e.g. "encoder.encoder.0"), so replace the layer on its parent
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, wn(module))
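
As a quick sanity check, here is a minimal sketch. It assumes Net from the question has been constructed (with its elided parts filled in) and uses apply_weight_norm as a hypothetical name for the replacement loop above. With torch.nn.utils.weight_norm, each wrapped convolution gains weight_g and weight_v parameters, so you can count them to confirm that every Conv2d in the VGG-16 encoder was covered:

import torch.nn as nn

model = Net()             # the Net class from the question, with its elided parts filled in
apply_weight_norm(model)  # hypothetical helper wrapping the replacement loop shown above

convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
wrapped = [m for m in convs if hasattr(m, "weight_g") and hasattr(m, "weight_v")]
print(f"wrapped {len(wrapped)} of {len(convs)} Conv2d layers with weight norm")

# the reparametrised weights also show up in the state dict as *.weight_g / *.weight_v
print([k for k in model.state_dict() if k.endswith("weight_g")][:3])

If the two counts match, weight normalization covers every convolution in the encoder. The pre-trained weights themselves are preserved, because weight_norm only reparametrizes the existing weight tensor into a magnitude (weight_g) and a direction (weight_v).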

