Unable to load model from checkpoint in Pytorch-Lightning

I am working with a U-Net in PyTorch Lightning. I am able to train the model successfully, but after training, when I try to load the model from the checkpoint, I get this error:

Complete traceback:

Traceback (most recent call last):
  File "src/train.py", line 269, in <module>
    main(sys.argv[1:])
  File "src/train.py", line 263, in main
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "src/train.py", line 162, in __init__
    self.inc = double_conv(self.n_channels, 64)
  File "src/train.py", line 122, in double_conv
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 406, in __init__
    super(Conv2d, self).__init__(
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 50, in __init__
    if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'dict' and 'int'

I tried searching GitHub issues and forums but was not able to figure out what the issue is. Please help.

Here's the code of my model and the checkpoint loading step:
Model:

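import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

class Unet(pl.LightningModule):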
    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # self.hparams = hparams

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = True
        self.logger = WandbLogger(name="Adam", project="pytorchlightning")

        def double_conv(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def down(in_channels, out_channels):
            return nn.Sequential(
                nn.MaxPool2d(2), double_conv(in_channels, out_channels)
            )

        class up(nn.Module):
            def __init__(self, in_channels, out_channels, bilinear=False):
                super().__init__()

                if bilinear:
                    self.up = nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=True
                    )
                else:
                    self.up = nn.ConvTranspose2d(
                        in_channels // 2, in_channels // 2, kernel_size=2, stride=2
                    )

                self.conv = double_conv(in_channels, out_channels)

            def forward(self, x1, x2):
                x1 = self.up(x1)
                # [?, C, H, W]
                diffY = x2.size()[2] - x1.size()[2]
                diffX = x2.size()[3] - x1.size()[3]

                x1 = F.pad(
                    x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
                )
                x = torch.cat([x2, x1], dim=1)
                return self.conv(x)

        self.inc = double_conv(self.n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)

        x2 = self.down1(x1)

        x3 = self.down2(x2)

        x4 = self.down3(x3)

        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return self.out(x)

    def training_step(self, batch, batch_nb):
        x, y = batch

        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)

        # wandb_logger.log_metrics({"loss":loss})
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.logger.log_metrics({"train_loss": avg_train_loss})
        return {"average_loss": avg_train_loss}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"test_loss": loss, "pred": y_hat}

    def test_end(self, outputs):

        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()

        return {"avg_test_loss": avg_loss}

    def MSE(self, logits, labels):

        return torch.mean((logits - labels) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1, weight_decay=1e-8)

Main function:

from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

def main(expconfig):
    # Define checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        filepath="/home/africa_wikilimo/data/model_checkpoint/",
        save_top_k=1,
        verbose=True,
        monitor="loss",
        mode="min",
        prefix="",
    )

    # Initialise datasets
    print("Initializing Climate Dataset....")
    clima_train = Clima_Dataset(expconfig[0])

    # Initialise dataloaders
    print("Initializing train_loader....")
    train_dataloader = DataLoader(clima_train, batch_size=2, num_workers=4)

    # Initialise model and trainer
    print("Initializing model...")
    model = Unet(n_channels=9, n_classes=5)
    print("Initializing Trainer....")
    if torch.cuda.is_available():

        model.cuda()

        trainer = pl.Trainer(
            max_epochs=1,
            gpus=1,
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=None,
        )
    else:

        trainer = pl.Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)
    
    trainer.fit(model, train_dataloader=train_dataloader)
    print(checkpoint_callback.best_model_path)
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)

Cause

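This happens because your model is unable to load its hyperparameters (n_channels, n_classes=5) from the checkpoint, since you never save them explicitly. The traceback shows the consequence: when load_from_checkpoint re-creates the model with cls(*cls_args, **cls_kwargs), the value that reaches n_channels is a dict rather than the integer 9 used for training, so nn.Conv2d fails on in_channels % groups.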
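The underlying idea can be illustrated with a plain-Python toy (hypothetical code, not Lightning's implementation): a checkpoint that holds only the weights cannot re-create an object whose constructor needs arguments, while one that also records those constructor arguments can. The names TinyModel, save_checkpoint, load_checkpoint and the "hyper_parameters" key are illustrative assumptions.

```python
class TinyModel:
    """Hypothetical stand-in for a model whose constructor needs arguments."""

    def __init__(self, n_channels, n_classes=5):
        # Toy "weights": one number per (output, input) channel pair.
        self.weights = [[0.0] * n_channels for _ in range(n_classes)]
        self.n_channels = n_channels
        self.n_classes = n_classes


def save_checkpoint(model, init_args=None):
    # Hypothetical checkpoint layout: the weights, plus (optionally) the
    # constructor arguments needed to rebuild the object later.
    ckpt = {"state": model.weights}
    if init_args is not None:
        ckpt["hyper_parameters"] = dict(init_args)
    return ckpt


def load_checkpoint(cls, ckpt):
    # Re-create the object from whatever arguments were stored,
    # then restore its state.
    model = cls(**ckpt.get("hyper_parameters", {}))
    model.weights = ckpt["state"]
    return model


model = TinyModel(n_channels=9)
restored = load_checkpoint(
    TinyModel, save_checkpoint(model, {"n_channels": 9, "n_classes": 5})
)
```

In the toy, leaving out the stored arguments makes load_checkpoint raise a TypeError for the missing n_channels. The message differs from the Lightning traceback above, but the root cause is the same: the constructor arguments were never saved with the checkpoint.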

Fix

You can resolve it by calling self.save_hyperparameters('n_channels', 'n_classes') in the __init__ method of your Unet class. Refer to the PyTorch Lightning hyperparameters docs for more details on this method. save_hyperparameters stores the selected parameters alongside the model weights in the checkpoint (and the logger may also write them to hparams.yaml), so load_from_checkpoint can pass them back to __init__ when it rebuilds the model.

Thanks to @Adrian Wälchli (awaelchli) from the PyTorch Lightning core contributors team, who suggested this fix when I faced the same issue.

