簡體   English   中英

如何將解碼器的輸入與預訓練的 Resnet18 編碼器匹配?

[英]How can I match a Decoder's input to a Pretrained Resnet18 Encoder?

我正在嘗試構建一個帶有跳過連接的自定義解碼器,以使用預訓練的 Resnet18 編碼器來運行圖像分割任務。 類別(class)總數為 150 個。

The Resnet18 Encoder has fc output of 512. In order to match the Encoder's output to Decoder's input, I am trying to set the Conv layers of Decoder such that it matches the output from Encoder ie [151, 512, 1, 1]. 但是,無論我制作什么層組合,我都無法匹配輸入和 output 張量。

這是解碼器代碼的相關部分

class ResNet18Transpose(nn.Module):
    """Transposed-convolution decoder intended to mirror a ResNet-18 encoder
    for semantic segmentation.

    Four stride-2 deconvolution stages progressively upsample the deepest
    encoder feature map; skip0..skip3 project encoder features so they can be
    added back in at each scale, and out6_conv..out2_conv are 1x1 heads that
    emit per-scale class logits for deep supervision.
    """

    def __init__(self, transblock, layers, num_classes=150):
        # Args:
        #   transblock: transposed residual block class; its `expansion`
        #       attribute scales every channel count below.
        #   layers: number of blocks per deconv stage -- should be
        #       [2, 2, 2, 2] to mirror ResNet-18.
        #   num_classes: number of segmentation classes (150 here).
        #
        # self.inplanes is consumed by _make_transpose (defined elsewhere in
        # this file -- not visible here) as the running input channel count;
        # 512 matches ResNet-18's deepest feature map.
        self.inplanes = 512
        super(ResNet18Transpose, self).__init__()
        
        # Upsampling stages: each halves the channel budget passed in and
        # uses stride=2 -- presumably doubling spatial resolution; confirm
        # against _make_transpose.
        self.deconv1 = self._make_transpose(transblock, 256 * transblock.expansion, layers[0], stride=2)
        self.deconv2 = self._make_transpose(transblock, 128 * transblock.expansion, layers[1], stride=2)
        self.deconv3 = self._make_transpose(transblock, 64 * transblock.expansion, layers[2], stride=2)
        self.deconv4 = self._make_transpose(transblock, 32 * transblock.expansion, layers[3], stride=2)
        

        # Skip projections: map encoder feature maps (64/128/256/512 input
        # channels, per the first argument) onto decoder channel widths for
        # the additive skip connections used in forward().
        self.skip0 = self._make_skip_layer(64, 64 * transblock.expansion)
        self.skip1 = self._make_skip_layer(128, 64 * transblock.expansion)
        self.skip2 = self._make_skip_layer(256, 64 * transblock.expansion)
        self.skip3 = self._make_skip_layer(512, 128 * transblock.expansion)
        
        # Reset the running channel count before building the final stack.
        self.inplanes = 64
        self.final_conv = self._make_transpose(transblock, 64 * transblock.expansion, 3)
        
        # Final 2x transposed conv straight to class logits.
        self.final_deconv = nn.ConvTranspose2d(self.inplanes * transblock.expansion, num_classes, kernel_size=2,
                                               stride=2, padding=0, bias=True)
        
        # Per-scale 1x1 prediction heads (deep-supervision outputs).
        # NOTE(review): out6_conv expects a 1024-channel input, but the
        # reported RuntimeError shows the deepest ResNet-18 feature map has
        # only 512 channels -- 1024 looks like a leftover from a
        # ResNet-50-style encoder; confirm and resize.
        self.out6_conv = nn.Conv2d(1024, num_classes, kernel_size=1, stride=1, bias=True)
        self.out5_conv = nn.Conv2d(128 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        self.out4_conv = nn.Conv2d(128 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        self.out3_conv = nn.Conv2d(64 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        self.out2_conv = nn.Conv2d(32 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        
        # He-style initialization for every conv / transposed conv; sync
        # batch-norm layers start as identity (weight=1, bias=0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

Decoder 的前向塊如下:

def forward(self, x, labels=None, sparse_mode=False, use_skip=True):
        """Decode five encoder feature maps into multi-scale class logits.

        Args:
            x: list of five encoder feature maps [in0..in4] -- assumed to be
                ordered shallowest to deepest; confirm against the encoder's
                return_feature_maps output.
            labels: optional list of five scaled label maps [lab0..lab4],
                used to build sparsity masks when sparse_mode is True.
            sparse_mode: if True, zero out feature positions whose class is 0
                (from the labels when given, else from the argmax of the
                current prediction) before each upsampling stage.
            use_skip: if True, add the projected encoder features back in at
                each scale (skip connections).

        Returns:
            [out6, out5, out4, out3, out2, out1]: class-logit maps from
            coarsest to finest resolution.
        """
        [in0, in1, in2, in3, in4] = x
        # NOTE(review): `if labels:` relies on list truthiness -- an empty
        # list is treated the same as None here.
        if labels:
            [lab0, lab1, lab2, lab3, lab4] = labels

        # Coarsest prediction, straight off the deepest encoder map.
        out6 = self.out6_conv(in4)
        
        if sparse_mode:
            # Mask out positions predicted (or labelled) as class 0, then
            # broadcast the mask across all feature channels.
            if labels:
                mask4 = (lab4==0).unsqueeze(1).repeat(1,in4.shape[1],1,1).type(in4.dtype)
            else:
                mask4 = (torch.argmax(out6, dim=1)==0).unsqueeze(1).repeat(1,in4.shape[1],1,1).type(in4.dtype)
            in4 = in4 * mask4

        # upsample 1
        x = self.deconv1(in4)
        out5 = self.out5_conv(x)
        
        if sparse_mode:
            if labels:
                mask3 = (lab3==0).unsqueeze(1).repeat(1,in3.shape[1],1,1).type(in3.dtype)
            else:
                mask3 = (torch.argmax(out5, dim=1)==0).unsqueeze(1).repeat(1,in3.shape[1],1,1).type(in3.dtype)
            in3 = in3 * mask3

        # Additive skip connection from the matching encoder scale.
        if use_skip:
            x = x + self.skip3(in3)
        
        # upsample 2
        x = self.deconv2(x)
        out4 = self.out4_conv(x)
        
        if sparse_mode:
            if labels:
                mask2 = (lab2==0).unsqueeze(1).repeat(1,in2.shape[1],1,1).type(in2.dtype)
            else:
                mask2 = (torch.argmax(out4, dim=1)==0).unsqueeze(1).repeat(1,in2.shape[1],1,1).type(in2.dtype)
            in2 = in2 * mask2

        if use_skip:
            x = x + self.skip2(in2)
        
        # upsample 3
        x = self.deconv3(x)
        out3 = self.out3_conv(x)
        
        if sparse_mode:
            if labels:
                mask1 = (lab1==0).unsqueeze(1).repeat(1,in1.shape[1],1,1).type(in1.dtype)
            else:
                mask1 = (torch.argmax(out3, dim=1)==0).unsqueeze(1).repeat(1,in1.shape[1],1,1).type(in1.dtype)
            in1 = in1 * mask1

        if use_skip:
            x = x + self.skip1(in1)
        
        # upsample 4
        x = self.deconv4(x)
        out2 = self.out2_conv(x)
        
        if sparse_mode:
            if labels:
                mask0 = (lab0==0).unsqueeze(1).repeat(1,in0.shape[1],1,1).type(in0.dtype)
            else:
                mask0 = (torch.argmax(out2, dim=1)==0).unsqueeze(1).repeat(1,in0.shape[1],1,1).type(in0.dtype)
            in0 = in0 * mask0

        if use_skip:
            x = x + self.skip0(in0)
        
        # final: refine, then one last 2x transposed conv to class logits.
        x = self.final_conv(x)
        out1 = self.final_deconv(x)

        return [out6, out5, out4, out3, out2, out1]

我收到以下錯誤:

 File "/project/xfu/aamir/Golden-QGN/models/resnet.py", line 447, in forward
    out6 = self.out6_conv(in4)
  File "/project/xfu/aamir/anaconda3/envs/QGN/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/project/xfu/aamir/anaconda3/envs/QGN/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 419, in forward
    return self._conv_forward(input, self.weight)
  File "/project/xfu/aamir/anaconda3/envs/QGN/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 415, in _conv_forward
    return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Given groups=1, weight of size [151, 1024, 1, 1], expected input[8, 512, 8, 8] to have 1024 channels, but got 512 channels instead
Exception in thread Thread-1:

如果我按如下方式更改解碼器層:

# Alternative 1x1 prediction-head sizing quoted in the question: out6_conv
# is resized to accept a 512-channel input (the actual ResNet-18 deepest
# feature map) and out5_conv to 256 * expansion channels. Fragment only --
# indentation is as pasted and not runnable standalone.
self.out6_conv = nn.Conv2d(512, num_classes, kernel_size=1, stride=1, bias=True)
        self.out5_conv = nn.Conv2d(256 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        self.out4_conv = nn.Conv2d(128 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        self.out3_conv = nn.Conv2d(64 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)
        self.out2_conv = nn.Conv2d(32 * transblock.expansion, num_classes, kernel_size=1, stride=1, bias=True)

我收到以下錯誤。 請注意,輸入張量也發生了變化。

 File "/project/xfu/aamir/anaconda3/envs/QGN/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 415, in _conv_forward
    return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 64, 3, 3], expected input[8, 32, 128, 128] to have 64 channels, but got 32 channels instead

我還嘗試將 Resnet18 編碼器的 fc output 更改為 1024 而不是 512,如下所示:

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet,
            with the classifier head replaced by a fresh 150-way layer.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet18']))
        # ResNet-18's final feature dimension is 512 (BasicBlock expansion
        # is 1), not 1024 -- size the new head from the existing fc layer
        # instead of hard-coding it, so the Linear input always matches the
        # backbone's output features.
        model.fc = torch.nn.Linear(model.fc.in_features, 150)
    return model

但似乎沒有任何效果。 我也無法在 github / 互聯網上找到基於跳過連接的 resnet18 解碼器。 任何幫助將不勝感激。

注意:我只想在 Resnet18 上工作。 此外,通過數據加載器輸入到網絡的圖像與 Resnet50 編碼器 + 自定義解碼器一起工作得非常好。 我還嘗試將裁剪尺寸(crop size)從 0 改為 128、256 再到 512,但徒勞無功。

這是在 models.py 中運行的用於圖像分割的代碼。 我嘗試使用 set_trace() 方法來調試代碼。 代碼在以下代碼中的(pred, pred_quad) = self.decoder(fmap, labels_scaled)行之前停止運行

class SegmentationModule(SegmentationModuleBase):
    """Bundles encoder + decoder + criterion into one module, with optional
    five-scale deep ("quad") supervision and adaptive loss weighting.
    """
    
    def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None, quad_sup=False, running_avg_param=0.99):
        # net_enc / net_dec: encoder and decoder modules.
        # crit: loss criterion applied to each (prediction, label) pair.
        # deep_sup_scale: geometric weight base for the auxiliary losses; a
        #     negative value switches on running-average adaptive weighting
        #     (with the base reset to 1).
        # quad_sup: enables the five-scale deep-supervision path.
        # running_avg_param: EMA factor for the adaptive loss weights.
        super(SegmentationModule, self).__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.crit = crit
        # NOTE(review): `if deep_sup_scale:` is False for 0 as well as None,
        # and self.loss_weights is only defined inside this branch, yet
        # forward() reads it whenever quad_sup is True -- quad_sup without a
        # non-zero deep_sup_scale would raise AttributeError; confirm callers.
        if deep_sup_scale:
          if deep_sup_scale < 0:
              self.adapt_weights = True
              self.running_avg_param = running_avg_param
              deep_sup_scale = 1
          else:
              self.adapt_weights = False
          # One weight per auxiliary scale: scale^(i+1) for i in 0..4.
          self.loss_weights = [(deep_sup_scale**(i+1)) for i in range(5)]
        self.quad_sup = quad_sup
            
    def forward(self, feed_dict, *, segSize=None):
        # segSize is None during training; any other value selects inference.
        inputs = feed_dict['img_data'].cuda()
        if segSize is None: # training
            labels_orig_scale = feed_dict['seg_label_0'].cuda()             
            labels_scaled = []
            fmap = self.encoder(inputs, return_feature_maps=True)
            if self.quad_sup:
                # Gather the five down-scaled label maps for deep supervision
                # and pass them to the decoder alongside the feature maps.
                labels_scaled.append(feed_dict['seg_label_1'].cuda())
                labels_scaled.append(feed_dict['seg_label_2'].cuda())
                labels_scaled.append(feed_dict['seg_label_3'].cuda())
                labels_scaled.append(feed_dict['seg_label_4'].cuda())
                labels_scaled.append(feed_dict['seg_label_5'].cuda())
                (pred, pred_quad) = self.decoder(fmap, labels_scaled)
            else:
                pred = self.decoder(fmap)
            # Main loss at the original label resolution.
            loss = self.crit(pred, labels_orig_scale)
            if self.quad_sup:
                loss_orig = loss
                for i in range(len(pred_quad)):
                    # Weighted auxiliary loss for each supervised scale.
                    loss_quad = self.crit(pred_quad[i], labels_scaled[i])
                    loss = loss + loss_quad * self.loss_weights[i]
                    if self.adapt_weights:
                        # EMA-update the weight toward the current
                        # auxiliary/main loss ratio.
                        self.loss_weights[i] = self.running_avg_param * self.loss_weights[i] + \
                        (1 - self.running_avg_param) * (loss_quad/loss_orig).data.cpu().numpy()

            # pixel_acc is provided by SegmentationModuleBase (not shown here).
            acc = self.pixel_acc(pred, labels_orig_scale, self.quad_sup)
            return loss, acc
        else: # inference
            if 'qtree' in feed_dict:
                # Quad-tree label maps for levels 1..5, if supplied.
                labels_scaled = [feed_dict['qtree'][l].cuda() for l in range(1,6)]                
            else:
                labels_scaled = None
            # NOTE(review): this call passes a segSize keyword -- the decoder's
            # forward must accept it; verify the deployed decoder signature.
            pred = self.decoder(self.encoder(inputs, return_feature_maps=True), labels_scaled, segSize=segSize)
            return pred

所以這就是我解決錯誤的方法。

  1. 禁用跳過連接。 不知道為什么解碼器不適用於跳過連接。
  2. 我給解碼器輸入了不正確的層 [6, 3, 4, 3]。 即對於 Resnet18,層應該是 [2,2,2,2]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM