[英]How can I load and use a PyTorch (.pth.tar) model
我对 Torch 不是很熟悉,我主要使用 Tensorflow。 但是,我需要使用在 Torch 中重新训练的重新训练的初始模型。 由于为我的特定应用程序重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型。
该模型保存为.pth.tar
文件。
我希望能够首先加载这个模型。 到目前为止,我已经能够弄清楚我必须使用以下内容:
model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')
这似乎有效,因为print(model)
打印出大量数字和其他值,我认为这些值是权重和偏差的值。
在此之后,我需要能够用它对图像进行分类。 我一直无法弄清楚这一点。 我必须如何格式化图像? 图像是否应该转换为数组? 在此之后,我必须如何将输入数据传递给网络?
你基本上需要做与张量流相同的事情。 也就是说,当您存储网络时,只会存储参数(即网络中的可训练对象),而不是“胶水”,这就是使用训练模型所需的全部逻辑。 因此,如果您有.pth.tar
文件,则可以加载它,从而覆盖已定义模型的参数值。
这意味着保存/加载模型的一般过程如下:
nn.Module
对象)torch.save
保存参数nn.Module
对象的相同定义来首先实例化 pytorch 网络torch.load
覆盖网络参数的值这是有关如何执行此操作的一些参考资料的讨论: pytorch forums
这是一个超短的 mwe:
# to store
torch.save({
'state_dict': model.state_dict(),
'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')
# to load
checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.