簡體   English   中英

如何使用未經訓練的特征提取層訓練 model? (火炬)

[英]How can I train a model with untrained feature extraction layers? (PyTorch)

我想知道如何使用未經訓練的特征提取網絡保持批量規范層處於活動狀態。

這是否會被視為使用“未經訓練”的網絡進行特征提取?:

class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super(DenseNetConv,self).__init__()
        original_model = models.densenet161(pretrained=False)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x

上面應該返回 [batch size, 2208] 的張量,但是,我想通過聲明pretrained=False來確保我基本上是從未經訓練的網絡中提取特征。

然后我使用以下內容來定義分類器層:

class MyDenseNetDens(torch.nn.Module):
    def __init__(self, nb_out=2):
        super().__init__()
        self.dens1 = torch.nn.Linear(in_features=2208, out_features=512)
        self.dens2 = torch.nn.Linear(in_features=512, out_features=128)
        self.dens3 = torch.nn.Linear(in_features=128, out_features=nb_out)
        
    def forward(self, x):
        x = self.dens1(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens2(x)
        x = torch.nn.functional.selu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        x = self.dens3(x)
        return x

最后在這里加入他們:

class MyDenseNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mrnc = MyDenseNetConv()
        self.mrnd = MyDenseNetDens()
    def forward(self, x):
        x = self.mrnc(x)
        x = self.mrnd(x)
        return x 

densenet = MyDenseNet()
densenet.to(device)
densenet.train()

如果我允許它進行訓練,例如通過應用densenet.train()這是否足以允許為每個小批量生成批量歸一化統計信息,以及允許在期間學習和應用運行均值和標准偏差推理,同時保持卷積層未經訓練?

不幸的是,它無論如何都會更新運行狀態,因為 self.mrnc 中的默認self.mrnc仍然使用運行均值和運行方差進行初始化。 您需要做的就是關閉它們:

class MyDenseNetConv(torch.nn.Module): # I renamed it to be correct to your code
    def __init__(self):
        super(MyDenseNetConv,self).__init__()
        original_model = models.densenet161(pretrained=False)
        for n, v in original_model.named_modules(): # this part is to remove the tracking stat option
            if 'norm' in n:
                v.track_running_stats=False
                v.running_mean = None
                v.running_var = None
                v.num_batches_tracked = None
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM