How to use pretrained weights of a model for initializing the weights in next iteration?

I have a model architecture, and I have saved the entire model with torch.save() after training it for some number of iterations n. I want to run another round of training that starts from the pre-trained weights of the model I saved previously.

Edit: I want the weights for the new run to be initialized from the weights of the pretrained model.

Edit 2: Just to add, I don't plan to resume training. I intend to save the model and use it for a separate training run with the same parameters. Think of it as using a saved model, with its weights etc., for a larger run with more samples (i.e. a completely new training job).

Right now, I do something like:

import torch
from torch import nn

# default_lr = 5
# default_weight_decay = 0.001
# model_io = the saved pretrained model (the file written by torch.save)
model = torch.load(model_io)  # loads the entire pickled model, not just a state_dict
optim = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
loss_new = nn.BCELoss()
epochs = default_epoch
.
.
training_loop():
....
outputs = model(input)
....
.
# similarly for the test loop

Am I missing something? I have to train for many epochs on a huge number of samples, so I cannot afford to wait for the results and only then figure things out.

Thank you!

From the code you posted, I see that you are loading only the previous model parameters in order to restart training from where you left off. This is not sufficient to restart training correctly. Along with the model parameters (weights), you also need to save and load the optimizer state. This matters especially when the optimizer is Adam, which keeps running estimates of the first and second moments of the gradient for every weight and uses them to scale each parameter's update.
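To make the point about optimizer state concrete, here is a minimal sketch that inspects what Adam stores after a single update. The `nn.Linear` model and the random batch are stand-ins chosen for illustration, not the asker's network or data.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # hypothetical stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Before any update, Adam holds no per-parameter state.
state_before = len(optimizer.state_dict()["state"])

# One dummy update populates the state for every parameter tensor.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

state_after = optimizer.state_dict()["state"]
# Each entry holds a step counter plus the running moment estimates
# (stored under the keys "exp_avg" and "exp_avg_sq").
keys = sorted(state_after[0].keys())
print(state_before, len(state_after), keys)
```

A freshly constructed optimizer starts with none of these buffers, which is why reloading only the weights does not reproduce the training state at the time of saving.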

In order to smoothly restart training, I would do the following:

import torch

# For saving your model
# (optim is the Adam optimizer used during training)
state = {
    'model': model.state_dict(),
    'optimizer': optim.state_dict()
}
model_save_path = "Enter/your/model/path/here/model_name.pth"
torch.save(state, model_save_path)

# ------------------------------------------

# For loading your model
state = torch.load(model_save_path)

# Rebuild the architecture first, then copy the saved weights into it
model = MyNetwork()
model.load_state_dict(state['model'])

# Create the optimizer on the new model's parameters, then restore its state
optim = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
optim.load_state_dict(state['optimizer'])
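As a sanity check on the save/load recipe, the following self-contained sketch runs the full round trip on a toy problem and confirms that the restored model and optimizer take the same next step as the originals. `MyNetwork` here is a hypothetical one-layer placeholder, and the data, hyperparameters, and temporary path are assumptions made for the example.

```python
import os
import tempfile

import torch
from torch import nn


class MyNetwork(nn.Module):
    """Hypothetical stand-in for the real architecture."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))


def train_step(model, optimizer, x, y):
    """Run one optimization step and return the loss before the update."""
    optimizer.zero_grad()
    loss = nn.functional.binary_cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
x = torch.randn(16, 4)
y = torch.randint(0, 2, (16, 1)).float()

model = MyNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
train_step(model, optimizer, x, y)

# Save weights and optimizer state together in one checkpoint file.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path)

# Restore both into freshly constructed objects.
state = torch.load(path)
restored = MyNetwork()
restored.load_state_dict(state["model"])
restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-3, weight_decay=1e-3)
restored_optimizer.load_state_dict(state["optimizer"])

# The original and the restored copy should now behave identically.
loss_a = train_step(model, optimizer, x, y)
loss_b = train_step(restored, restored_optimizer, x, y)
weights_match = all(
    torch.allclose(p, q) for p, q in zip(model.parameters(), restored.parameters())
)
print(weights_match, abs(loss_a - loss_b))
```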

Besides these, you may also want to save your learning rate (or the scheduler state) if you are using a learning rate decay strategy, your best validation accuracy so far for checkpointing purposes, and any other changing quantity that affects your training. In most cases, though, saving and loading just the model weights and optimizer state should be sufficient.
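One way to carry those extra items is to put them in the same checkpoint dictionary. The sketch below assumes a `StepLR` scheduler and uses hypothetical key names (`scheduler`, `epoch`, `best_val_acc`). None of these choices come from the question, and the model is again a stand-in.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # hypothetical stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Assumed decay strategy: halve the learning rate after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

for epoch in range(3):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epoch": epoch + 1,    # number of completed epochs
    "best_val_acc": 0.0,   # placeholder value, no validation is run here
}
path = os.path.join(tempfile.mkdtemp(), "checkpoint_full.pth")
torch.save(checkpoint, path)

# Restore everything in the same order it was built.
state = torch.load(path)
new_model = nn.Linear(4, 1)
new_model.load_state_dict(state["model"])
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
new_optimizer.load_state_dict(state["optimizer"])
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.5)
new_scheduler.load_state_dict(state["scheduler"])

start_epoch = state["epoch"]
current_lr = new_optimizer.param_groups[0]["lr"]
print(start_epoch, current_lr)
```

The decayed learning rate travels inside the optimizer's `param_groups`, while the scheduler's `state_dict()` carries its own epoch counter, so restoring both keeps the decay schedule aligned with the epoch count.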

EDIT: You may also want to look at the following answer, which explains in detail how you should save your model in different scenarios.
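One such scenario is the one in the question's Edit 2: a separate training job that only borrows the pretrained weights. A possible sketch, under the assumption that the file was written with `torch.save(model, ...)` as the question describes, is to load the pickled model, copy its weights into a freshly built network, and start with a brand-new optimizer. `build_model` and its layers are hypothetical. With a custom model class, that class must also be importable when the file is loaded.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    # Hypothetical architecture, not the asker's network.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())


torch.manual_seed(0)
pretrained = build_model()
model_io = os.path.join(tempfile.mkdtemp(), "model_full.pth")
torch.save(pretrained, model_io)  # whole-model save, as in the question

# Recent PyTorch releases default to weights_only=True, which rejects
# pickled modules, so it is disabled here. Only do this for trusted files.
loaded = torch.load(model_io, weights_only=False)

# Warm start: copy the pretrained weights into a new model instance.
model = build_model()
model.load_state_dict(loaded.state_dict())

# New job, new optimizer: no Adam statistics are carried over.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
criterion = nn.BCELoss()

same_weights = all(
    torch.equal(p, q) for p, q in zip(model.parameters(), pretrained.parameters())
)
fresh_optimizer = len(optimizer.state_dict()["state"]) == 0
print(same_weights, fresh_optimizer)
```

Whether to reuse the old optimizer state is a judgment call. Restoring it gives continuity with the earlier run, while a fresh optimizer treats the weights purely as an initialization.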
