

How to get part of a pre-trained model in PyTorch?

I have built an autoencoder model named AutoEncoderNew(), and after training I want to use the encoder and the decoder separately for some experiments on new samples.

The model class is as follows:

import functools

import torch
import torch.nn as nn


class AutoEncoderNew(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, cond=0, ngf=32, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9):
        assert(n_blocks >= 0)
        super(AutoEncoderNew, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.cond = cond

        self.filters = []
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        # ---
        # Maybe cut the model in half by using increments of 4?
        model01 = [nn.Conv2d(3, ngf*4, kernel_size=5, padding=2, stride=2,
                            bias=use_bias),
                  norm_layer(ngf*4),
                  nn.ReLU(0.2)]  # NB: nn.ReLU has no slope argument; 0.2 is interpreted as the inplace flag (nn.LeakyReLU(0.2) may have been intended)
        model02 = [nn.Conv2d(ngf*4, ngf*8, kernel_size=5,
                             stride=2, padding=2, bias=use_bias),
                   norm_layer(ngf*8),
                   nn.ReLU(0.2)]
        model03 = [nn.Conv2d(ngf*8, ngf*12, kernel_size=5,
                             stride=2, padding=2, bias=use_bias),
                   norm_layer(ngf*12),
                   nn.ReLU(0.2)]
        model03 += [nn.Conv2d(ngf*12, ngf*16, kernel_size=5,
                             stride=2, padding=2, bias=use_bias),
                   norm_layer(ngf*16),
                   nn.ReLU(0.2)]
        model04 = [nn.Conv2d(ngf*16, ngf*16, kernel_size=3,
                             stride=1, padding=1, bias=use_bias),
                   norm_layer(ngf*16),
                   nn.ReLU(0.2),
                   nn.MaxPool2d(2, stride=2),
                   norm_layer(ngf*16),
                   nn.ReLU(0.2)]
        model04 += [nn.Conv2d(ngf*16, ngf*12, kernel_size=3,
                             stride=1, padding=1, bias=use_bias),
                   norm_layer(ngf*12),
                   nn.ReLU(0.2),
                   nn.MaxPool2d(2, stride=2),
                   norm_layer(ngf*12),
                   nn.ReLU(0.2)]

        model2 = [nn.ConvTranspose2d(ngf*12, ngf*16, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf*16),
                   nn.ReLU(True)]
        model2 += [nn.ConvTranspose2d(ngf*16, ngf*16, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf*16),
                   nn.ReLU(True)]
        model2 += [nn.ConvTranspose2d(ngf*16, ngf*16, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf*16),
                   nn.ReLU(True)]
        model2 += [nn.ConvTranspose2d(ngf*16, ngf*12, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf*12),
                   nn.ReLU(True)]
        model2 += [nn.ConvTranspose2d(ngf*12, ngf*8, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf*8),
                   nn.ReLU(True)]
        model2 += [nn.ConvTranspose2d(ngf*8, ngf*4, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf*4),
                   nn.ReLU(True)]
        model2 += [nn.ConvTranspose2d(ngf*4, ngf, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   norm_layer(ngf),
                   nn.ReLU(True)]
        model2 += [nn.Conv2d(ngf, 3, kernel_size=4,
                             stride=2, padding=1, bias=use_bias),
                   nn.Tanh()]  

        self.model01 = nn.Sequential(*model01)
        self.model02 = nn.Sequential(*model02)
        self.model03 = nn.Sequential(*model03)
        self.model04 = nn.Sequential(*model04)
        self.model2 = nn.Sequential(*model2)

    def forward(self, x, cond=0):  # NB: the cond argument is unused here; the branch is chosen by self.cond set in __init__
        #Reconstruction Only
        if self.cond == 0:
            enc = self.model04(self.model03(self.model02(self.model01(x))))
            dec = self.model2(enc)
            return dec
        #Encode Only
        if self.cond == 1:
            enc = self.model04(self.model03(self.model02(self.model01(x))))
            return enc
        #Decoder Only
        if self.cond == 2:
            dec = self.model2(x)
            return dec
        #Pixel_features
        if self.cond == 3:
            f01 = self.model01(x)
            f02 = self.model02(f01)
            f03 = self.model03(f02)
            f04 = self.model04(f03)
            return f01, f02, f03, f04

Here I use the cond value to obtain the encoder and the decoder as follows:

encoder = AutoEncoderNew(3,3,2,1) # cond=1 for encoder
encoder.load_state_dict(auto_model.state_dict())
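
Similarly for the decoder, and a rough sketch of how I intend to run new samples through the two halves (x below is just a placeholder for a suitably sized batch of 3-channel images):

decoder = AutoEncoderNew(3,3,2,2) # cond=2 for decoder
decoder.load_state_dict(auto_model.state_dict())

encoder.eval()
decoder.eval()
with torch.no_grad():
    z = encoder(x)      # encode new samples
    recon = decoder(z)  # decode the encoding back to image space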

Is this the correct way to go about it? Is there anything wrong with this approach?

Without going into your code, I'll briefly explain the general approach with a rough sketch; the layer lists below are placeholders rather than your exact layers. Hope you get the idea.

class AutoEncoderNew(nn.Module):
    def __init__(self):
        super().__init__()
        # placeholder stacks: substitute your actual layers and arguments
        self.Encoder = nn.Sequential(Conv2d, Norm, ReLU, ..., Conv2d, Norm, ReLU)
        self.Decoder = nn.Sequential(ConvTranspose2d, Norm, ReLU, ..., ConvTranspose2d, Norm, ReLU)

    def forward(self, x):
        encoded_x = self.Encoder(x)
        decoded_x = self.Decoder(encoded_x)
        return decoded_x

Later, when you load your model:

autoencoder = AutoEncoderNew()
autoencoder.load_state_dict(auto_model.state_dict())

you can now access a specific block of the loaded model as follows:

encoder = autoencoder.Encoder
decoder = autoencoder.Decoder
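
With your original AutoEncoderNew this works the same way, because the blocks are already registered as the submodules model01–model04 and model2. A minimal sketch, assuming auto_model holds the trained weights and x stands for a suitably shaped batch of 3-channel images:

# sketch: reuse the trained sub-blocks directly, without the cond flag
autoencoder = AutoEncoderNew(3, 3, 2)
autoencoder.load_state_dict(auto_model.state_dict())
autoencoder.eval()

# the encoder is the four downsampling blocks chained together
encoder = nn.Sequential(autoencoder.model01, autoencoder.model02,
                        autoencoder.model03, autoencoder.model04)
decoder = autoencoder.model2

with torch.no_grad():
    z = encoder(x)        # x: a placeholder batch of new samples
    recon = decoder(z)

nn.Sequential here only wraps references to the already-loaded submodules, so encoder and decoder share their parameters with autoencoder rather than copying them.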
