Unable to load model from checkpoint in Pytorch-Lightning
I am working with a U-Net in Pytorch Lightning. I am able to train the model successfully, but after training, when I try to load the model from checkpoint, I get this error:
Complete Traceback:
Traceback (most recent call last):
  File "src/train.py", line 269, in <module>
    main(sys.argv[1:])
  File "src/train.py", line 263, in main
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "src/train.py", line 162, in __init__
    self.inc = double_conv(self.n_channels, 64)
  File "src/train.py", line 122, in double_conv
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 406, in __init__
    super(Conv2d, self).__init__(
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 50, in __init__
    if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'dict' and 'int'
I tried searching through the GitHub issues and forums, but I am not able to figure out what the issue is. Please help.
Here's the code of my model and the checkpoint loading step:
Model:
class Unet(pl.LightningModule):
    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # self.hparams = hparams
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = True
        self.logger = WandbLogger(name="Adam", project="pytorchlightning")

        def double_conv(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def down(in_channels, out_channels):
            return nn.Sequential(
                nn.MaxPool2d(2), double_conv(in_channels, out_channels)
            )

        class up(nn.Module):
            def __init__(self, in_channels, out_channels, bilinear=False):
                super().__init__()
                if bilinear:
                    self.up = nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=True
                    )
                else:
                    self.up = nn.ConvTranspose2d(
                        in_channels // 2, in_channels // 2, kernel_size=2, stride=2
                    )
                self.conv = double_conv(in_channels, out_channels)

            def forward(self, x1, x2):
                x1 = self.up(x1)
                # [?, C, H, W]
                diffY = x2.size()[2] - x1.size()[2]
                diffX = x2.size()[3] - x1.size()[3]
                x1 = F.pad(
                    x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
                )
                x = torch.cat([x2, x1], dim=1)
                return self.conv(x)

        self.inc = double_conv(self.n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.out(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        # wandb_logger.log_metrics({"loss":loss})
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.logger.log_metrics({"train_loss": avg_train_loss})
        return {"average_loss": avg_train_loss}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"test_loss": loss, "pred": y_hat}

    def test_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        return {"avg_test_loss": avg_loss}

    def MSE(self, logits, labels):
        return torch.mean((logits - labels) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1, weight_decay=1e-8)
Main Function:
def main(expconfig):
    # Define checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        filepath="/home/africa_wikilimo/data/model_checkpoint/",
        save_top_k=1,
        verbose=True,
        monitor="loss",
        mode="min",
        prefix="",
    )

    # Initialise datasets
    print("Initializing Climate Dataset....")
    clima_train = Clima_Dataset(expconfig[0])

    # Initialise dataloaders
    print("Initializing train_loader....")
    train_dataloader = DataLoader(clima_train, batch_size=2, num_workers=4)

    # Initialise model and trainer
    print("Initializing model...")
    model = Unet(n_channels=9, n_classes=5)

    print("Initializing Trainer....")
    if torch.cuda.is_available():
        model.cuda()
        trainer = pl.Trainer(
            max_epochs=1,
            gpus=1,
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=None,
        )
    else:
        trainer = pl.Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)

    trainer.fit(model, train_dataloader=train_dataloader)
    print(checkpoint_callback.best_model_path)
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
Cause
This happens because your model is unable to load the hyperparameters (n_channels, n_classes=5) from the checkpoint, as you do not save them explicitly. load_from_checkpoint re-instantiates Unet from whatever was stored in the checkpoint, and since the constructor arguments were never saved, a dict ends up where n_channels is expected, which is why Conv2d fails with unsupported operand type(s) for %: 'dict' and 'int'.
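If you want to confirm this, you can inspect the checkpoint file directly with torch.load. This is a minimal sketch; the exact key under which Lightning stores hyperparameters ('hyper_parameters' in recent versions) varies across releases:

import torch

ckpt = torch.load(checkpoint_callback.best_model_path, map_location="cpu")
# Typically contains 'state_dict', 'epoch', etc., plus 'hyper_parameters'
# only if they were saved explicitly.
print(ckpt.keys())
# Without saved hyperparameters, load_from_checkpoint has nothing to pass as
# n_channels/n_classes when it calls Unet(...) again, and loading fails.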
Fix
You can resolve it by calling self.save_hyperparameters('n_channels', 'n_classes') in your Unet class's __init__ method. Refer to the PyTorch Lightning hyperparams-docs for more details on the use of this method. Using save_hyperparameters lets the selected params be saved in the hparams.yaml along with the checkpoint.
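As a minimal sketch of the fix, only the top of __init__ changes; the rest of the class stays as above:

class Unet(pl.LightningModule):
    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # Persist the constructor arguments into the checkpoint so that
        # load_from_checkpoint can call Unet(n_channels=..., n_classes=...) later.
        self.save_hyperparameters('n_channels', 'n_classes')
        self.n_channels = n_channels
        self.n_classes = n_classes
        # ... rest of __init__ (double_conv, down, up and the layer setup) unchanged ...

After retraining with this change, loading needs no extra arguments:

model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)

Alternatively, load_from_checkpoint forwards extra keyword arguments to the model constructor, so for an already-trained checkpoint Unet.load_from_checkpoint(checkpoint_callback.best_model_path, n_channels=9, n_classes=5) should also work as a workaround, though behaviour may vary across Lightning versions.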
Thanks to @Adrian Wälchli (awaelchli) from the PyTorch Lightning core contributors team, who suggested this fix when I faced the same issue.