
Deconstruct and reconstruct a pretrained network in PyTorch

I want to use the G network I found here: https://github.com/scaleway/frontalization/blob/master/network.py together with the pretrained weights here: https://github.com/scaleway/frontalization/tree/master/Pretrained/generator_v0.pt

I know it is made up of an encoder (from a 128x128 image down to a 512-dimensional vector) and a decoder (from the vector back to an image). I need exactly the opposite: from a 512-dimensional vector to a 128x128 image, and then back to a vector.

Can I easily invert the encoder/decoder? And after doing that first step, can I still use the weights from the pretrained .pt file?

My current solution

First, I created a new model class that swaps the encoder and decoder:

import torch.nn as nn

class G2(nn.Module):
    
    def __init__(self):
        super(G2, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 1, 0, bias = False), # Output HxW = 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False), # Output HxW = 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False), # Output HxW = 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias = False), # Output HxW = 32x32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias = False), # Output HxW = 64x64
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 3, 4, 2, 1, bias = False), # Output HxW = 128x128
            nn.Tanh(),
            # Input HxW = 128x128
            nn.Conv2d(3, 16, 4, 2, 1), # Output HxW = 64x64
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 4, 2, 1), # Output HxW = 32x32
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), # Output HxW = 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1), # Output HxW = 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1), # Output HxW = 4x4
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 4, 2, 1), # Output HxW = 2x2
            nn.MaxPool2d((2,2))
            # At this point, we arrive at our low D representation vector, which is 512 dimensional.
        )
    
    def forward(self, input):
        output = self.main(input)
        return output
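As a quick sanity check (my own sketch, not part of the original question code), pushing a dummy 512x1x1 latent through G2 should reproduce the shapes in the layer comments and come back out as a 512x1x1 tensor:

import torch

# smoke test with a single random latent vector, shaped as the first ConvTranspose2d expects
netG2_test = G2()
z = torch.randn(1, 512, 1, 1)
out = netG2_test(z)
print(out.shape)  # expected: torch.Size([1, 512, 1, 1]) after the final MaxPool2d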

I then loaded the pretrained model and created a new instance:

netG = torch.load('generator_v0.pt')
netG2 = G2()
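To double-check the split index 17 used in the copy loop below, one can print the indexed sub-modules of the loaded generator and see where the encoder ends and the decoder begins (a small inspection sketch; it assumes netG.main is an nn.Sequential laid out encoder-first, as in network.py):

# list the sub-modules of the pretrained generator with their Sequential indices
for idx, module in enumerate(netG.main):
    print(idx, module)
# if the layout mirrors the G2 class above, indices 0-16 are the encoder
# (ending in MaxPool2d) and indices 17-33 are the decoder (starting with ConvTranspose2d)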

and copied the layers (with their pretrained weights) over (not sure this is the most "elegant" way to do it):

# netG.main holds 17 encoder modules followed by 17 decoder modules, so the first half
# of G2 (the decoder) comes from the second half of netG, and vice versa
for idx in range(len(netG2.main)):
    if idx < 17:
        netG2.main[idx] = netG.main[17 + idx]
    else:
        netG2.main[idx] = netG.main[idx - 17]
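For comparison, a more compact way to express the same swap (just a sketch; it relies on the 17/17 split assumed above and, like the loop, reuses the same module objects rather than copying their weights):

import torch.nn as nn

modules = list(netG.main.children())
# rebuild the Sequential with the decoder half first, then the encoder half
netG2.main = nn.Sequential(*modules[17:], *modules[:17])

Either way the modules are shared, not deep-copied, so training netG2 further would also change netG's weights; copy.deepcopy can be used if independent copies are needed.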

Thanks for your feedback.
