
How to use a pretrained BERT model somewhere else?

I followed this course https://www.coursera.org/learn/sentiment-analysis-bert about building a pretrained model for sentiment analysis. During training, they saved the model at each epoch using torch.save(model.state_dict(), f'BERT_ft_epoch{epoch}.model'). Now I want to use one of these models (the best one, obviously) elsewhere, for example in an application where a user can paste a tweet as input and get the emotion of the writer. But I don't know how to load the model and predict. Here's what I tried:

import torchvision.models as models
import torch

model = models.resnet101(pretrained=False)
model.load_state_dict(torch.load('Models/BERT_ft_epoch15.model'), strict=False)
model_ft.eval()
output = model_ft(input) #input is a tweets list

I get this error: TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not list

resnet101 and BERT are two totally different models. You cannot load a pretrained BERT model into resnet.
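Part of the reason the attempt above only fails at prediction time is that strict=False tells load_state_dict to skip every key that does not match, so nothing is actually loaded. A minimal sketch of how to check this, using two small stand-in modules (ModelA and ModelB are hypothetical names, not BERT or ResNet):

```python
import torch.nn as nn

# Two unrelated stand-in architectures (hypothetical, for illustration only).
class ModelA(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)

class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)

saved = ModelA().state_dict()

# strict=False suppresses the error, but the return value reports the mismatch.
result = ModelB().load_state_dict(saved, strict=False)
print(result.missing_keys)     # parameters of ModelB that received no weights
print(result.unexpected_keys)  # saved keys that ModelB has no place for
```

If both lists are non-empty like this, the weights belong to a different architecture, and the model has to be rebuilt with the same class that was used during training.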

How to define, initialize, save and load models using PyTorch.

Initializing a model. You define a model by subclassing nn.Module. Consider this simple two-layer model:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size=128, output_size=10):
        super().__init__()

        # hidden layer followed by a nonlinearity
        self.layer1 = nn.Sequential(nn.Linear(input_size, 64), nn.LeakyReLU())
        self.layer2 = nn.Linear(64, output_size)
    
    def forward(self, x):
        y = self.layer2(self.layer1(x))
        return y

The layers of the model are created in __init__(), and the operations of the forward pass are specified in forward(). You can be creative there, just remember to use differentiable PyTorch operations so that gradients can flow during training.

You initialize the model by creating an instance of the new class:

model = Model() # brand new instance!
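As a quick sanity check, a new instance can be called on a random batch to confirm the output shape. This is only a sketch: the batch size of 4 and the random input are made up for illustration, and the class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size=128, output_size=10):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Linear(input_size, 64), nn.LeakyReLU())
        self.layer2 = nn.Linear(64, output_size)

    def forward(self, x):
        return self.layer2(self.layer1(x))

model = Model()
batch = torch.randn(4, 128)  # 4 samples with 128 features each (made-up data)
logits = model(batch)        # calling the instance runs forward()
print(logits.shape)          # one row of 10 outputs per sample
```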

After training your model you want to save it:

import torch
model = Model(128, 10) # initialization

torch.save(model.state_dict(), 'model.pt') # saving state dict (note the call: state_dict())

You are not saving the model object here, you are saving its state_dict: an ordered dictionary that maps each parameter and buffer name to its tensor (weights, biases and so on). The reason we save the state_dict instead of the model directly is explained in the documentation ( https://pytorch.org/tutorials/beginner/saving_loading_models.html ). For now, just consider it best practice.
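To see what a state_dict holds, you can list its keys and tensor shapes. A small sketch, using an nn.Sequential stand-in with the same layer sizes as the two-layer example (the stand-in is an assumption made to keep the snippet short):

```python
import torch.nn as nn

# Stand-in network with the same layer sizes as the two-layer example.
model = nn.Sequential(nn.Linear(128, 64), nn.LeakyReLU(), nn.Linear(64, 10))

# Map every entry of the state_dict to the shape of its tensor.
shapes = {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
```

Only the two Linear layers appear, because LeakyReLU has no parameters to store.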

Finally, we arrive at how to load the model. You have to initialize the model first, then load the state_dict from disk.

model = Model(128, 10) # model initialization
model.load_state_dict(torch.load('model.pt')) # read the state_dict from disk, then load it
model.eval() # put the model in inference mode
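The whole save and load round trip can be sketched as follows. The temporary directory, the make_model helper and the random input are assumptions for illustration. The key point is that the restored model must be built with the same architecture before its weights are loaded.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Hypothetical helper: builds the same architecture every time.
    return nn.Sequential(nn.Linear(128, 64), nn.LeakyReLU(), nn.Linear(64, 10))

trained = make_model()  # stands in for a model you have just trained
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(trained.state_dict(), path)

restored = make_model()  # same architecture, fresh random weights
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()          # inference mode

x = torch.randn(2, 128)  # made-up input batch
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)  # whether both models produce identical outputs
```

Passing map_location="cpu" lets a file saved on a GPU machine load on a machine without one.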

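Notice that when we save the state_dict we may also save the optimizer's state_dict, together with bookkeeping values such as the current epoch and loss. That is useful to checkpoint the training and resume it at a later stage. (The autograd graph itself is not saved; it is rebuilt on each forward pass.)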

    # in the training loop
    torch.save({"epoch": epoch,
                "model": model.state_dict(),   # call the method to get the dict
                "optim": optim.state_dict(),   # optimizer state (e.g. momentum buffers)
                "loss": loss}, f'checkpoint{epoch}.pt')
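Resuming from such a checkpoint mirrors the saving step: rebuild the model and the optimizer, then load each state_dict back. A minimal sketch, assuming a tiny nn.Linear model, an SGD optimizer with momentum and a single made-up training step (none of which come from the original post):

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One made-up training step so the optimizer has momentum buffers to save.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optim.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint3.pt")
torch.save({"epoch": 3,
            "model": model.state_dict(),
            "optim": optim.state_dict(),
            "loss": loss.item()}, path)

# Later: rebuild the same objects, then restore their state.
new_model = nn.Linear(4, 2)
new_optim = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)

ckpt = torch.load(path, map_location="cpu")
new_model.load_state_dict(ckpt["model"])
new_optim.load_state_dict(ckpt["optim"])
start_epoch = ckpt["epoch"] + 1  # continue with the next epoch
new_model.train()                # back to training mode
```

Storing loss.item() rather than the loss tensor keeps the checkpoint free of references to the autograd graph.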

I hope that paints a clearer picture for you =)
