[英]Getting "Unable to load weights from pytorch checkpoint file" when loading model from transformers
[英]exceptions when loading the checkpoint of a PyTorch NN model
当调用以下单元格中定义的 function 时,抛出异常' TypeError: forward() 需要 2 个位置 arguments 但给出了 9 个'本文档提供了更多详细信息
def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model2 = model1(checkpoint['input_size'],
checkpoint['output_size'],
checkpoint['fc1'],
checkpoint['fc2'],
checkpoint['optimizer_state_dict'],
checkpoint['epoch'],
checkpoint['class_to_idx'],
checkpoint['learning_rate'])
model2.load_state_dict(checkpoint['state_dict'])
return model2
写出检查点的代码如下:
checkpoint ={'input_size':512,
'output_size':102,
'fc1':256,
'fc2':102,
'state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch+1,
'class_to_idx': model.class_to_idx,
'learning_rate': 0.003}
torch.save(checkpoint,chkptJP)
您的错误表明model1
是已经实例化的网络,而它应该是 class。 有关全面信息,请参阅有关保存的官方文档(如有疑问,请始终参考)。 我将在整个答案中链接到它,因此请务必检查并了解发生了什么。
您的代码保存了一个通用检查点。 你可以用这种方式保存任何字典和任何你想要的信息(它基本上是Python 的泡菜,你也可以类似地调整它)。 您的信息很多,其中一些与 model 本身无关。
正如您所做的那样,您可以通过torch.load
加载所有这些数据。 由于您保存了state_dict
(权重),而不是整个Model
(它的代码看起来如何),您必须使用随机权重创建 model a-new 并在之后加载它们。
这段代码应该没问题:
def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model = ModelClass(
checkpoint["input_size"],
checkpoint["output_size"],
checkpoint["fc1"],
checkpoint["fc2"],
checkpoint["optimizer_state_dict"],
checkpoint["epoch"],
checkpoint["class_to_idx"],
checkpoint["learning_rate"],
)
model.load_state_dict(checkpoint["state_dict"])
return model
注意ModelClass
必须是 class,而不是像您在此处所做的 object。 如果model1
是 object,则运行 model1 model1(arg1, ..., arg9)
将调用它的__call__
方法,如果model1
是torch.nn.Module
的实例,则该方法又是一个包装好的forward
方法。 ModelClass
在您的代码中应该是这样的(并且可能在某处定义):
import torch
class ModelClass(torch.nn.Module):
def __init__(
self,
input_size,
output_size,
fc1,
fc2,
optimizer_state_dict,
epoch,
class_to_idx,
learning_rate,
):
# Your initialization code here
...
def forward(tensor):
# Your forward pass here
...
如果您在任何地方都没有ModelClass
,则必须单独保存整个 model(例如, torch.save(model)
而不是torch.save(model.state_dict()))
并将其作为一个整体加载( torch.load(PATH)
代替chkp=torch.load(PATH)
后跟model.load_state_dict
调用实例)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.