繁体   English   中英

如何使用pytorch中的checkpoint模型文件来测试CIFAR-10数据集?

[英]How to use checkpoint model file in pytorch to test the CIFAR-10 dataset?

model = SqueezeNext()
model = model.to(device)

def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):
# Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
start_epoch = 0
if os.path.isfile(filename):
    print("=> loading checkpoint '{}'".format(filename))
    checkpoint = torch.load(filename)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    losslogger = checkpoint['losslogger']
    print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'".format(filename))


return model, optimizer, start_epoch, losslogger

model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger)

TypeError: Traceback (最近调用 last) in () 41 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False) 42 ---> 43 model = SqueezeNext() 44 model = model.to(device) 45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'): TypeError: init () 缺少 3 个必需的位置参数:'width_x'、'blocks' 和 'num_classes'

我想我没有以正确的方式实施这个!!

您的错误不是来自您的检查点功能。 如果我们看一下回溯:

> TypeError: Traceback (most recent call last)
> <ipython-input-51-94c8be648862> in <module>()
>      41 test_loader   = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False)
>      42 
> ---> 43 model = SqueezeNext()
>      44 model = model.to(device)
>      45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'): TypeError: __init__() missing 3
> required positional arguments: 'width_x', 'blocks', and 'num_classes'

我们被告知的这条线正在打破第 43 行:

> ---> 43 model = SqueezeNext()

错误是:

> required positional arguments: 'width_x', 'blocks', and 'num_classes'

我假设您正在使用 SqueezeNext 的这个实现,但是无论您使用哪个实现,您都没有传递初始化模型所需的所有参数。 您需要将代码更改为以下内容:

model = SqueezeNext(width_x=1.0, blocks=[6, 6, 8, 1], num_classes=10)

如果您没有使用该实现,则需要找到SqueezeNext模型的源代码,并查看__init__函数需要哪些参数。 你可以试试这个:

import inspect

inspect.signature(SqueezeNext.__init__)

这应该给你签名。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM