
How to save a training checkpoint of a model's weights and continue training from the last checkpoint in PyTorch?

I am trying to save checkpoint weights of the trained model after a certain number of epochs, then continue training from the last checkpoint for another number of epochs using PyTorch. To achieve this, I wrote the scripts below.

Training the model:

import torch
import segmentation_models_pytorch as smp
from tqdm import tqdm

def create_model():
  # load model from the segmentation_models_pytorch package
  model = smp.Unet(
      encoder_name="resnet152",       # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
      encoder_weights='imagenet',     # use ImageNet pre-trained weights for encoder initialization
      in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
      classes=2,                      # model output channels (number of classes in your dataset)
  )
  return model

model = create_model()
model.to(device)
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
epochs = 5

for epoch in range(epochs):
    print('Epoch: [{}/{}]'.format(epoch+1, epochs))

    # train set
    pbar = tqdm(train_loader)
    model.train()
    iou_logger = iouTracker()
    for batch in pbar:
        # load image and mask into device memory
        image = batch['image'].to(device)
        mask = batch['mask'].to(device)

        # pass images into model
        pred = model(image)
        # pred = checkpoint['model_state_dict']

        # get loss
        loss = criteria(pred, mask)

        # update the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # compute and display progress
        iou_logger.update(pred, mask)
        mIoU = iou_logger.get_mean()
        pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

    # development set
    pbar = tqdm(development_loader)
   
    model.eval()
    iou_logger = iouTracker()
    with torch.no_grad():
        for batch in pbar:
            # load image and mask into device memory
            image = batch['image'].to(device)
            mask = batch['mask'].to(device)

            # pass images into model
            pred = model(image)

            # get loss
            loss = criteria(pred, mask)
            
            # compute and display progress
            iou_logger.update(pred, mask)
            mIoU = iou_logger.get_mean()
            pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

# save model
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, '/content/drive/MyDrive/checkpoint.pt')

With this, I can save the model checkpoint file as checkpoint.pt after 5 epochs.

To continue training from the saved checkpoint weights, I wrote another script below:

epochs = 5    
for epoch in range(epochs):
    print('Epoch: [{}/{}]'.format(epoch+1, epochs))

    # train set
    pbar = tqdm(train_loader)


    checkpoint = torch.load('/content/drive/MyDrive/checkpoint.pt')
    print(checkpoint)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    model.train()
    iou_logger = iouTracker()
    for batch in pbar:
        # load image and mask into device memory
        image = batch['image'].to(device)
        mask = batch['mask'].to(device)

        # pass images into model
        pred = model(image)
        # pred = checkpoint['model_state_dict']

        # get loss
        loss = criteria(pred, mask)

        # update the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # compute and display progress
        iou_logger.update(pred, mask)
        mIoU = iou_logger.get_mean()
        pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

    # development set
    pbar = tqdm(development_loader)
   
    model.eval()
    iou_logger = iouTracker()
    with torch.no_grad():
        for batch in pbar:
            # load image and mask into device memory
            image = batch['image'].to(device)
            mask = batch['mask'].to(device)

            # pass images into model
            pred = model(image)

            # get loss
            loss = criteria(pred, mask)
            
            # compute and display progress
            iou_logger.update(pred, mask)
            mIoU = iou_logger.get_mean()
            pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

# save model
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pt')

This raises the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-31-54f48c10531a> in <module>()


---> 14     model.load_state_dict(checkpoint['model_state_dict'])



/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226 

RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.encoder.conv1.weight", "module.encoder.bn1.weight", "module.encoder.bn1.bias", "module.encoder.bn1.running_mean", "module.encoder.bn1.running_var", "module.encoder.layer1.0.conv1.weight", "module.encoder.layer1.0.bn1.weight", "module.encoder.layer1.0.bn1.bias", "module.encoder.layer1.0.bn1.running_mean", "module.encoder.layer1.0.bn1.running_var", "module.encoder.layer1.0.conv2.weight", "module.encoder.layer1.0.bn2.weight", "module.encoder.layer1.0.bn2.bias", "module.encoder.layer1.0.bn2.running_mean", "module.encoder.layer1.0.bn2.running_var", "module.encoder.layer1.0.conv3.weight", "module.encoder.layer1.0.bn3.weight", "module.encoder.layer1.0.bn3.bias", "module.encoder.layer1.0.bn3.running_mean", "module.encoder.layer1.0.bn3.running_var", "module.encoder.layer1.0.downsample.0.weight", "module.encoder.layer1.0.downsample.1.weight", "module.encoder.layer1.0.downsample.1.bias", "module.encoder.layer1.0.downsample.1.running_mean", "module.encoder.layer1.0.downsample.1.running_var", "module.encoder.layer1.1.conv1.weight", "module.encoder.layer1.1.bn1.weight", "module.encoder.layer1.1.bn1.bias", "module.encoder.layer1.1.bn1.running_mean", "module.encoder.layer1.1.bn1.running_var", "module.encoder.layer1.1.conv2.weight", "module.encoder.layer1.1.bn2.weight", "module.encoder.layer1.1.bn2.bias", "module.encoder.layer1.1.bn2.running_mean", "module.encoder.layer1.1.bn2.running_var", "module.encoder.layer1.1.conv3.weight", "module.encoder.layer...
    Unexpected key(s) in state_dict: "encoder.conv1.weight", "encoder.bn1.weight", "encoder.bn1.bias", "encoder.bn1.running_mean", "encoder.bn1.running_var", "encoder.bn1.num_batches_tracked", "encoder.layer1.0.conv1.weight", "encoder.layer1.0.bn1.weight", "encoder.layer1.0.bn1.bias", "encoder.layer1.0.bn1.running_mean", "encoder.layer1.0.bn1.running_var", "encoder.layer1.0.bn1.num_batches_tracked", "encoder.layer1.0.conv2.weight", "encoder.layer1.0.bn2.weight", "encoder.layer1.0.bn2.bias", "encoder.layer1.0.bn2.running_mean", "encoder.layer1.0.bn2.running_var", "encoder.layer1.0.bn2.num_batches_tracked", "encoder.layer1.1.conv1.weight", "encoder.layer1.1.bn1.weight", "encoder.layer1.1.bn1.bias", "encoder.layer1.1.bn1.running_mean", "encoder.layer1.1.bn1.running_var", "encoder.layer1.1.bn1.num_batches_tracked", "encoder.layer1.1.conv2.weight", "encoder.layer1.1.bn2.weight", "encoder.layer1.1.bn2.bias", "encoder.layer1.1.bn2.running_mean", "encoder.layer1.1.bn2.running_var", "encoder.layer1.1.bn2.num_batches_tracked", "encoder.layer1.2.conv1.weight", "encoder.layer1.2.bn1.weight", "encoder.layer1.2.bn1.bias", "encoder.layer1.2.bn1.running_mean", "encoder.layer1.2.bn1.running_var", "encoder.layer1.2.bn1.num_batches_tracked", "encoder.layer1.2.conv2.weight", "encoder.layer1.2.bn2.weight", "encoder.layer1.2.bn2.bias", "encoder.layer1.2.bn2.running_mean", "encoder.layer1.2.bn2.running_var", "encoder.layer1.2.bn2.num_batches_tracked", "encoder.layer2.0.conv1.weight", "encoder.layer...

What am I doing wrong? How can I fix this? Any help would be appreciated.

This line:

model.load_state_dict(checkpoint['model_state_dict'])

should instead be:

model.load_state_dict(checkpoint)
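Note that the error message itself points at the real mismatch: every "Missing key" starts with `module.`, the prefix that `nn.DataParallel` adds to parameter names when it wraps a model, while the checkpoint's keys have no such prefix. A minimal sketch of renaming the keys follows; the helper names `strip_module_prefix` and `add_module_prefix` are illustrative, not part of PyTorch, and plain strings stand in for tensors so the demonstration is self-contained:

```python
PREFIX = 'module.'

def strip_module_prefix(state_dict):
    """Copy of state_dict with a leading 'module.' removed from each key."""
    return {
        (k[len(PREFIX):] if k.startswith(PREFIX) else k): v
        for k, v in state_dict.items()
    }

def add_module_prefix(state_dict):
    """Copy of state_dict with 'module.' prepended to each key."""
    return {PREFIX + k: v for k, v in state_dict.items()}

# Demonstration with toy keys modeled on the traceback:
saved = {'encoder.conv1.weight': 'w', 'encoder.bn1.bias': 'b'}
wrapped = add_module_prefix(saved)   # what a DataParallel-wrapped model expects
print(sorted(wrapped))
# ['module.encoder.bn1.bias', 'module.encoder.conv1.weight']
```

With a real checkpoint you would call `model.load_state_dict(add_module_prefix(checkpoint['model_state_dict']))`, or, more simply, load the checkpoint into the unwrapped model first and wrap it in `DataParallel` afterwards.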

You need to create a new model object to load the state dict into, as suggested in the official guide.

So before you run the second training stage:


model = create_model()
model.load_state_dict(checkpoint['model_state_dict'])

# then start the training loop

You are loading the state dict inside the epoch loop. You need to load it once, before the loop...
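Putting these suggestions together, the corrected resume flow can be sketched as follows. This is a minimal, self-contained illustration, not the poster's actual script: `nn.Linear(4, 2)` stands in for the original `smp.Unet`, the first few lines simulate the checkpoint produced by the first run, and the real training-loop body from the question is elided:

```python
import torch
import torch.nn as nn

device = 'cpu'

# Simulate the first run: build model/optimizer and save a checkpoint.
model = nn.Linear(4, 2)  # stand-in for create_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

torch.save({
    'epoch': 4,  # last finished epoch of the first run (epochs 0..4)
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pt')

# Second run: recreate the model and optimizer objects FIRST,
# then restore the checkpoint ONCE, before entering the epoch loop.
model = nn.Linear(4, 2)  # fresh model object, as the answer suggests
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume where the first run stopped
model.to(device)

epochs = 10
for epoch in range(start_epoch, epochs):
    model.train()
    # ... training-loop body exactly as in the question ...
```

Resuming the `range` at `checkpoint['epoch'] + 1` also keeps the epoch counter continuous across runs instead of restarting it at 0.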

