繁体   English   中英

加载 PyTorch NN model 的检查点时的异常

[英]exceptions when loading the checkpoint of a PyTorch NN model

当调用以下单元格中定义的 function 时,抛出异常' TypeError: forward() 需要 2 个位置 arguments 但给出了 9 个'文档提供了更多详细信息

def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model2 = model1(checkpoint['input_size'],
              checkpoint['output_size'],
              checkpoint['fc1'],
              checkpoint['fc2'],
              checkpoint['optimizer_state_dict'],
              checkpoint['epoch'],
              checkpoint['class_to_idx'],
              checkpoint['learning_rate'])
model2.load_state_dict(checkpoint['state_dict'])
return model2

写出检查点的代码如下:

checkpoint ={'input_size':512,
         'output_size':102,
         'fc1':256,
         'fc2':102,
         'state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'epoch': epoch+1,
         'class_to_idx': model.class_to_idx,
         'learning_rate': 0.003}
torch.save(checkpoint,chkptJP)

您的错误表明model1是已经实例化的网络,而它应该是 class。 有关全面信息,请参阅有关保存的官方文档(如有疑问,请始终参考)。 我将在整个答案中链接到它,因此请务必检查并了解发生了什么。

保存常规检查点

您的代码保存了一个通用检查点 你可以用这种方式保存任何字典和任何你想要的信息(它基本上是Python 的泡菜,你也可以类似地调整它)。 您的信息很多,其中一些与 model 本身无关。

加载常规检查点

正如您所做的那样,您可以通过torch.load加载所有这些数据。 由于您保存了state_dict (权重),而不是整个Model (它的代码看起来如何),您必须使用随机权重创建 model a-new 并在之后加载它们。

这段代码应该没问题:

def load_checkpoint(chkptJP):
    checkpoint = torch.load(chkptJP)
    model = ModelClass(
        checkpoint["input_size"],
        checkpoint["output_size"],
        checkpoint["fc1"],
        checkpoint["fc2"],
        checkpoint["optimizer_state_dict"],
        checkpoint["epoch"],
        checkpoint["class_to_idx"],
        checkpoint["learning_rate"],
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model

注意ModelClass必须是 class,而不是像您在此处所做的 object。 如果model1是 object,则运行 model1 model1(arg1, ..., arg9)将调用它的__call__方法,如果model1torch.nn.Module的实例,则该方法又是一个包装好的forward方法。 ModelClass在您的代码中应该是这样的(并且可能在某处定义):

import torch


class ModelClass(torch.nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        fc1,
        fc2,
        optimizer_state_dict,
        epoch,
        class_to_idx,
        learning_rate,
    ):
        # Your initialization code here
        ...

    def forward(tensor):
        # Your forward pass here
        ...

如果您在任何地方都没有ModelClass ,则必须单独保存整个 model(例如, torch.save(model)而不是torch.save(model.state_dict()))并将其作为一个整体加载( torch.load(PATH)代替chkp=torch.load(PATH)后跟model.load_state_dict调用实例)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM