How to initialize weights in a PyTorch model

I've got a fairly straightforward problem here. I've just finished reconfiguring a network by replacing nn.Upsample with the upConv sequential container shown in the code below. I've verified that everything lines up by running summary(UNetPP, (3, 128, 128)), which runs without issue.

import torch.nn as nn


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class blockUNetPP(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)


    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out

class upConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upc = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch*2, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch*2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.upc(x)
        return out

My issue is that when I try to start training the model, I get the following error:

Traceback (most recent call last):
  File "runTrain.py", line 90, in <module>
    netG.apply(weights_init)
  File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 289, in apply
    module.apply(fn)
  File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 290, in apply
    fn(self)
  File "D:\Thesis Models\Deep_learning_models\UNet\train\NetC.py", line 8, in weights_init
    m.weight.data.normal_(0.0, 0.02)
  File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 594, in __getattr__
    type(self).__name__, name))
AttributeError: 'upConv' object has no attribute 'weight'

I've looked up solutions that suggest looping over container modules, but I'm already doing this with weights_init(m). Could someone explain what's wrong with my current setup?

You decide how to initialise the weights by checking whether the class name contains Conv, using classname.find('Conv'). Your class is named upConv, which contains Conv, so the function tries to initialise its .weight attribute. That attribute doesn't exist: upConv is only a wrapper, and the weights belong to the nn.Conv2d inside self.upc.
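The name collision can be reproduced without PyTorch. The sketch below uses two hypothetical stand-in classes (a Conv2d that has a weight and an upConv wrapper that does not; neither is the real torch class) to compare the substring check with a type check:

```python
# Hypothetical stand-ins for illustration only; these are NOT the torch classes.
class Conv2d:
    def __init__(self):
        self.weight = [0.0]  # a real conv layer owns a weight tensor


class upConv:
    def __init__(self):
        self.upc = Conv2d()  # the wrapper only holds a layer, it has no weight


def matches_by_name(m):
    # Same test as in the question: substring search on the class name.
    return m.__class__.__name__.find('Conv') != -1


def matches_by_type(m):
    # Stricter test: only true for actual Conv2d instances.
    return isinstance(m, Conv2d)


wrapper = upConv()
name_hit = matches_by_name(wrapper)       # True, because 'upConv' contains 'Conv'
type_hit = matches_by_type(wrapper)       # False, the wrapper is not a Conv2d
has_weight = hasattr(wrapper, 'weight')   # False, so wrapper.weight would raise
```

The name check accepts the wrapper even though it has no weight, which is the situation behind the AttributeError in the traceback.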

Either rename your class or make the condition stricter, such as classname.find('Conv2d'). The strictest approach is to check whether the module is an instance of nn.Conv2d, instead of looking at the name of the class.

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
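As a rough check, the type-based function can be applied to a small wrapper module whose name still contains Conv. This is a minimal sketch, not the original network: UpConvDemo is a hypothetical, simplified stand-in for upConv, and the sketch uses the torch.nn.init helpers in place of the .data calls, which should give the same initialisation.

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Match on the layer type, not on the class name.
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)


class UpConvDemo(nn.Module):  # hypothetical stand-in for upConv
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upc = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.upc(x)


net = UpConvDemo(3, 8)
# apply() calls weights_init on every submodule and then on net itself;
# the wrapper and the Sequential are skipped because they match neither type.
net.apply(weights_init)
bn_bias_is_zero = bool(torch.all(net.upc[1].bias == 0))
```

Because apply() already recurses into containers, weights_init only needs to handle the leaf layers it cares about.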
