
How can I train a model with untrained feature extraction layers? (PyTorch)

I was wondering how I can keep my batch norm layers active while using an untrained feature-extraction network.

Would this be considered feature extraction with an "untrained" network?:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)  # randomly initialised (untrained) weights
        # keep everything except the final classifier layer
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # defined but unused; forward() uses F.avg_pool2d instead
        for param in self.parameters():
            param.requires_grad = False  # freeze the feature extractor

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)  # -> [batch_size, 2208]
        return x

The above should return a tensor of shape [batch_size, 2208]; however, I want to make sure that by setting pretrained=False I am actually extracting features from an untrained network.
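As a sanity check, here is a minimal sketch (assuming the DenseNetConv class and imports above) that confirms the extractor is frozen and produces the expected output shape:

# Illustrative check: the backbone is frozen and outputs [batch_size, 2208] features.
extractor = DenseNetConv()
print(all(not p.requires_grad for p in extractor.parameters()))  # True -> every parameter is frozen

with torch.no_grad():
    dummy = torch.randn(4, 3, 224, 224)    # a fake batch of four 3x224x224 images
    feats = extractor(dummy)
print(feats.shape)                          # torch.Size([4, 2208])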

I then use the following to define the classifier layers:

class MyDenseNetDens(torch.nn.Module):
    def __init__(self, nb_out=2):
        super().__init__()
        self.dens1 = torch.nn.Linear(in_features=2208, out_features=512)
        self.dens2 = torch.nn.Linear(in_features=512, out_features=128)
        self.dens3 = torch.nn.Linear(in_features=128, out_features=nb_out)
        
    def forward(self, x):
        x = self.dens1(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens2(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens3(x)
        return x

and finally join them together here:

class MyDenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mrnc = MyDenseNetConv()
        self.mrnd = MyDenseNetDens()
    def forward(self, x):
        x = self.mrnc(x)
        x = self.mrnd(x)
        return x 

densenet = MyDenseNet()
densenet.to(device)
densenet.train()
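For the actual training step, a hedged sketch (the optimiser, learning rate, and loss below are illustrative choices, not from the original setup) would pass only the parameters that still require gradients, i.e. the classifier head:

# Illustrative training step: only requires_grad=True parameters (the classifier) are optimised.
optimizer = torch.optim.Adam((p for p in densenet.parameters() if p.requires_grad), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(4, 3, 224, 224).to(device)   # dummy batch
labels = torch.randint(0, 2, (4,)).to(device)     # dummy binary labels
optimizer.zero_grad()
loss = criterion(densenet(images), labels)
loss.backward()
optimizer.step()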

If I allow this to train, for example by calling densenet.train(), will that be sufficient for batch-normalisation statistics to be computed on each mini-batch, and for the running means and standard deviations to be learnt and applied during inference, while keeping the convolutional layers untrained?
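For concreteness, a minimal sketch of how one could check whether the running statistics actually change (assuming the DenseNetConv class and imports above):

extractor = DenseNetConv()
extractor.train()

bn = extractor.features[0].norm0              # first BatchNorm2d of densenet161's feature stack
before = bn.running_mean.clone()
_ = extractor(torch.randn(4, 3, 224, 224))
print(torch.equal(before, bn.running_mean))   # False -> running stats were updated despite the frozen weights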

Unfortunately, it will update the running statistics anyway, since the default batch-norm layers in your self.mrnc are still initialised with a running mean and a running variance. All you need to do is turn them off:

class MyDenseNetConv(torch.nn.Module):  # renamed to match the MyDenseNetConv used in your MyDenseNet
    def __init__(self):
        super(MyDenseNetConv, self).__init__()
        original_model = models.densenet161(pretrained=False)
        # disable running-stat tracking on every batch-norm layer
        for n, v in original_model.named_modules():
            if 'norm' in n:
                v.track_running_stats = False
                v.running_mean = None
                v.running_var = None
                v.num_batches_tracked = None
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # defined but unused, as in your original class
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x
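
A quick way to confirm the effect of this change, as a minimal sketch assuming the modified MyDenseNetConv above:

extractor = MyDenseNetConv()
bn = extractor.features[0].norm0
print(bn.track_running_stats, bn.running_mean, bn.running_var)  # False None None -> nothing left to update

extractor.train()
out = extractor(torch.randn(4, 3, 224, 224))   # batch statistics are used; no running stats are tracked
print(out.shape)                                # torch.Size([4, 2208])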
