
How to use a checkpoint model file in PyTorch to test the CIFAR-10 dataset?

import os
import torch

# device, optimizer and losslogger are assumed to be defined earlier in the notebook.
model = SqueezeNext()
model = model.to(device)

def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        losslogger = checkpoint['losslogger']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, losslogger

model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger)

TypeError                                 Traceback (most recent call last)
<ipython-input-51-94c8be648862> in <module>()
     41 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False)
     42
---> 43 model = SqueezeNext()
     44 model = model.to(device)
     45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):

TypeError: __init__() missing 3 required positional arguments: 'width_x', 'blocks', and 'num_classes'

I think I am not implementing this in the right manner!

Your error isn't coming from your checkpoint function. If we look at the traceback:

> TypeError                                 Traceback (most recent call last)
> <ipython-input-51-94c8be648862> in <module>()
>      41 test_loader   = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False)
>      42 
> ---> 43 model = SqueezeNext()
>      44 model = model.to(device)
>      45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):
>
> TypeError: __init__() missing 3 required positional arguments: 'width_x', 'blocks', and 'num_classes'

The traceback tells us that the line that breaks is line 43:

> ---> 43 model = SqueezeNext()

And the error is:

> TypeError: __init__() missing 3 required positional arguments: 'width_x', 'blocks', and 'num_classes'

I'm assuming you're using this implementation of SqueezeNext, but whichever implementation you're using, you're not passing all the arguments needed to initialise the model. You'll want to change your code to something like:

model = SqueezeNext(width_x=1.0, blocks=[6, 6, 8, 1], num_classes=10)
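Once the model is constructed, the checkpoint still has to be loaded into it and run over the test set to actually test on CIFAR-10. Below is a minimal sketch, assuming the checkpoint stores the weights under `'state_dict'` as the question's `load_checkpoint` expects, and that `test_loader` yields `(images, labels)` batches. `load_for_testing` and `evaluate` are hypothetical helpers, and `nn.Linear` is only a stand-in for SqueezeNext so the sketch runs on its own. For testing alone, the optimizer state and `losslogger` are not needed. The constructor arguments must also match the ones used during training; otherwise `load_state_dict` will complain about missing/unexpected keys or shape mismatches.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def load_for_testing(model, filename, device):
    # The model must already be built with the SAME constructor arguments used in
    # training, otherwise load_state_dict reports missing/unexpected keys or shape mismatches.
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    return model.to(device)


def evaluate(model, loader, device):
    # eval() disables dropout and makes BatchNorm use its running statistics.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed at test time
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)  # index of the largest logit
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


# --- Self-contained demo: nn.Linear is a hypothetical stand-in for SqueezeNext ---
device = torch.device('cpu')
trained = nn.Linear(2, 2)
with torch.no_grad():
    trained.weight.copy_(torch.eye(2))  # predicts class 0 for [1, 0], class 1 for [0, 1]
    trained.bias.zero_()
path = os.path.join(tempfile.mkdtemp(), 'demo.ckpt')
torch.save({'epoch': 3, 'state_dict': trained.state_dict()}, path)

restored = load_for_testing(nn.Linear(2, 2), path, device)
x = torch.tensor([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])
y = torch.tensor([0, 1, 1, 1])
accuracy = evaluate(restored, DataLoader(TensorDataset(x, y), batch_size=2), device)
print(accuracy)  # 3 of the 4 predictions match the labels
```

With the real model, the same helpers would be called as `load_for_testing(model, 'SqNxt_23_1x_Cifar.ckpt', device)` followed by `evaluate(model, test_loader, device)`.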

If you're not using that implementation, you'll need to find the source code for the SqueezeNext model and see which arguments its __init__ method requires. You can try this:

import inspect

inspect.signature(SqueezeNext.__init__)

This should show you the constructor's signature, including which arguments it requires.
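To list only the mandatory arguments, you can filter the signature down to the parameters that have no default value. This is a sketch: the `SqueezeNext` class here is a hypothetical stand-in that only mirrors the three required arguments named in the error message, and `required_init_args` is an illustrative helper, not part of PyTorch.

```python
import inspect


class SqueezeNext:
    # Hypothetical stand-in with the same required arguments as in the error message.
    def __init__(self, width_x, blocks, num_classes):
        self.width_x = width_x
        self.blocks = blocks
        self.num_classes = num_classes


def required_init_args(cls):
    # Named parameters without a default must be supplied by the caller;
    # *args / **kwargs are skipped because they are never mandatory.
    params = inspect.signature(cls.__init__).parameters.values()
    return [p.name for p in params
            if p.name != 'self'
            and p.default is inspect.Parameter.empty
            and p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                           inspect.Parameter.POSITIONAL_OR_KEYWORD,
                           inspect.Parameter.KEYWORD_ONLY)]


print(inspect.signature(SqueezeNext.__init__))  # (self, width_x, blocks, num_classes)
print(required_init_args(SqueezeNext))          # ['width_x', 'blocks', 'num_classes']
```

Every name in that list has to be passed when the model is created, which is exactly what the `TypeError` is complaining about.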

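Statement: The technical posts on this site are published under the CC BY-SA 4.0 license. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.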

 