繁体   English   中英

如何使用未经训练的特征提取层训练 model? (火炬)

[英]How can I train a model with untrained feature extraction layers? (PyTorch)

我想知道如何使用未经训练的特征提取网络保持批量规范层处于活动状态。

这是否会被视为使用“未经训练”的网络进行特征提取?:

class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv,self).__init__()
        original_model = models.densenet161(pretrained=False)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x

上面应该返回 [batch size, 2208] 的张量,但是,我想通过声明pretrained=False来确保我基本上是从未经训练的网络中提取特征。

然后我使用以下内容来定义分类器层:

class MyDenseNetDens(torch.nn.Module):
    def __init__(self, nb_out=2):
        super().__init__()
        self.dens1 = torch.nn.Linear(in_features=2208, out_features=512)
        self.dens2 = torch.nn.Linear(in_features=512, out_features=128)
        self.dens3 = torch.nn.Linear(in_features=128, out_features=nb_out)
        
    def forward(self, x):
        x = self.dens1(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens2(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens3(x)
        return x

最后在这里加入他们:

class MyDenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mrnc = MyDenseNetConv()
        self.mrnd = MyDenseNetDens()
    def forward(self, x):
        x = self.mrnc(x)
        x = self.mrnd(x)
        return x 

densenet = MyDenseNet()
densenet.to(device)
densenet.train()

如果我允许它进行训练,例如通过应用densenet.train()这是否足以允许为每个小批量生成批量归一化统计信息,以及允许在期间学习和应用运行均值和标准偏差推理,同时保持卷积层未经训练?

不幸的是,它无论如何都会更新运行状态,因为 self.mrnc 中的默认self.mrnc仍然使用运行均值和运行方差进行初始化。 您需要做的就是关闭它们:

class MyDenseNetConv(torch.nn.Module): # I renamed it to be correct to your code
    def __init__(self):
        super(MyDenseNetConv,self).__init__()
        original_model = models.densenet161(pretrained=False)
        for n, v in original_model.named_modules(): # this part is to remove the tracking stat option
            if 'norm' in n:
                v.track_running_stats=False
                v.running_mean = None
                v.running_var = None
                v.num_batches_tracked = None
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM