
PyTorch: How to define a new neural network that utilizes transfer learning

I am migrating from the Keras/TensorFlow ecosystem and I have a little trouble understanding the transfer learning process in PyTorch.

I want to use the pytorch-lightning framework and I want to be able to switch between different neural networks in one script.

Per this example, we can switch between different neural networks inside one LightningModule:

class BERT(pl.LightningModule):
    def __init__(self, model_name, task):
        super().__init__()
        self.task = task

        if model_name == 'transformer':
            self.net = Transformer()
        elif model_name == 'my_cool_version':
            self.net = MyCoolVersion()
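
For example, I can then pick the backbone by name when constructing the module (assuming Transformer and MyCoolVersion are defined elsewhere; the argument values are just placeholders):

# hypothetical usage: select the network implementation by name
model = BERT(model_name='transformer', task='classification')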

The question is: how do I create a new neural network that extends nn.Module and utilizes the transfer learning process?

My own implementation looks like this: I am using the vgg16 network and have replaced its classifier with a single fully connected layer with two output neurons.

import torch.nn as nn
from torchvision.models import vgg16

class VGGNetwork(nn.Module):
    def __init__(self):
        super(VGGNetwork, self).__init__()
        # vgg16 is the default model here; we could also use the bn variant etc.
        self.model = vgg16(pretrained=True)

        # replace the original classifier with a single fc layer with 2 output neurons
        self.model.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 2))

    def forward(self, x):
        return self.model(x)

Is this the correct way to do that?

You can freeze the weights and biases of every layer of the network except the last one.

You can do this by setting requires_grad = False on the parameters:

# freeze all parameters of the pretrained network
for param in model_conv.parameters():
    param.requires_grad = False
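
For instance, here is a minimal sketch that combines this with the VGGNetwork class from the question (the variable names are mine): freeze the pretrained convolutional backbone and give the optimizer only the parameters that still require gradients.

import torch.optim as optim

net = VGGNetwork()

# freeze every parameter of the pretrained feature extractor
for param in net.model.features.parameters():
    param.requires_grad = False

# pass only the trainable parameters (the new classifier) to the optimizer
optimizer = optim.SGD(
    (p for p in net.parameters() if p.requires_grad), lr=0.001, momentum=0.9
)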

You can find more about this in the official PyTorch transfer learning tutorial: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

The pytorch-lightning documentation also shows this pattern: https://pytorch-lightning.readthedocs.io/en/0.7.1/transfer_learning.html

    ...

class AutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

class CIFAR10Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...
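
Putting both ideas together, here is a minimal sketch of how the vgg16 setup from the question could be wrapped in a LightningModule (the class name VGGTransfer, the loss, and the optimizer settings are my own illustration, not from the docs):

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision.models import vgg16

# hypothetical module applying the freezing idea to the question's vgg16 setup
class VGGTransfer(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = vgg16(pretrained=True)
        # freeze the pretrained feature extractor
        for param in self.backbone.features.parameters():
            param.requires_grad = False
        # replace the classifier with a single fc layer with 2 outputs
        self.backbone.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 2))

    def forward(self, x):
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        # optimize only the parameters that are still trainable
        return torch.optim.Adam(
            (p for p in self.parameters() if p.requires_grad), lr=1e-3
        )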
