繁体   English   中英

PyTorch:如何定义一个利用迁移学习的新神经网络

[英]PyTorch: How to define a new neural network that utilizes transfer learning

我正在从 Keras/TF 框架迁移,并且在理解 PyTorch 中的迁移学习过程时遇到了一些麻烦。

我想使用 pytorch-lightning 框架,并且想在一个脚本中在不同的神经网络之间切换。

根据这个例子,我们可以在不同的神经网络之间进行切换:

class BERT(pl.LightningModule):
def __init__(self, model_name, task):
    self.task = task

    if model_name == 'transformer':
        self.net = Transformer()
    elif model_name == 'my_cool_version':
        self.net = MyCoolVersion()

问题是:如何创建一个新的神经网络来扩展 nn.Module 并利用迁移学习过程?

我自己的实现看起来像这样:我使用 vgg16 网络,并用两个输出神经元仅用一个 fc 替换了分类器层。

class VGGNetwork(nn.Module):
    def __init__(self):
        super(VGGNetwork, self).__init__()
        # vgg16 is the default model here, we can use bn etc...
        self.model = vgg16(pretrained=True)

        # removing the last three layers of classifier only 2 ...
        self.model.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 2))

def forward(self, x):
    return self.model.forward(x)

这是正确的方法吗?

除了最后一层之外,您可以冻结神经网络层的权重和 bais。

你可以使用requires_grad = False

for param in model_conv.parameters():
    param.requires_grad = False

您可以在以下链接中找到更多相关信息https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

https://pytorch-lightning.readthedocs.io/en/0.7.1/transfer_learning.html

    ...

class AutoEncoder(pl.LightningModule):
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

class CIFAR10Classifier(pl.LightingModule):
    def __init__(self):
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM