简体   繁体   English

如何在 pytorch 中评估训练有素的 model?

[英]How to evaluate a trained model in pytorch?

I have trained a model and save model using torch.save.我训练了一个 model 并使用 torch.save 保存了 model。 Then after training I have loaded the model using train.load but I am getting this error然后在训练后我使用 train.load 加载了 model 但我收到了这个错误


Traceback (most recent call last):
  File "/home/fsdfs.py", line 219, in <module>
    test(model, 'cuda', testloader)
  File "/home/fsdfs.py", line 201, in test
    model.eval()
AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

Here is my code for test part这是我的测试部分代码

model = torch.load("train_5.pth")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to('cuda'), target.to('cuda')
            output = model(data)
            #test_loss += f.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            print(pred, target)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
         correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


test(model, 'cuda', testloader)

I have commented training part of the code in the file, so in a way this and loading the data part is all that is there in the file now.我已经评论了文件中代码的训练部分,所以在某种程度上,这个和加载数据部分就是文件中现在的全部内容。

What am I doing wrong?我究竟做错了什么?

Like @jodag has said.就像@jodag 所说的那样。 you probably have saved a state_dict instead of a model, which is recommended by the community as well.您可能保存了 state_dict 而不是社区推荐的 model。

This link explains the difference between two. 此链接解释了两者之间的区别。 To keep my answer self contained, I copy the snippet from the documentation.为了使我的答案独立,我从文档中复制了片段。 Here is the recommended way:这是推荐的方法:

Save:救:

torch.save(model.state_dict(), PATH)

Load:加载:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

You could also save the entire model instead of saving the state_dict, if you really need to use the model the way you do.如果您确实需要按照您的方式使用 model,您也可以保存整个 model 而不是保存 state_dict。

Save:救:

torch.save(model, PATH)

Load:加载:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM