简体   繁体   English

在PyTorch中加载模型时出错

[英]Error in loading model in PyTorch

I Have the following code snippet 我有以下代码片段

from train import predict
import random
import torch


ann=torch.load('ann.pt') #importing trained model


while True:
      k=raw_input("User:")
      intent,top_value,top_index = predict(str(k),ann)
      print(intent)

when I run the script it is throwing the error as below: 当我运行脚本时,它会引发如下错误:

Traceback (most recent call last):
  File "test.py", line 6, in <module>
    ann=torch.load('ann.pt') #importing trained model
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 261, in load
    return _load(f, map_location, pickle_module)
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 409, in _load
    result = unpickler.load()
AttributeError: 'module' object has no attribute 'ANN'

I have ann.pt file in the same folder as my script is. 我的脚本位于同一文件夹中的ann.pt文件。 Kindly help me identify fix the error and load the model. 请帮助我确定修复错误并加载模型。 Thanks in advance. 提前致谢。

When trying to save both parameters and model, pytorch pickles the parameters but only store path the model Class. 尝试同时保存参数和模型时,pytorch会腌制参数,但仅存储模型Class的路径。 For instance, changing tree structure or refactoring can break loading. 例如,更改树结构或重构可能会破坏加载。 Therefore as the documentation points out , it is not recommended, prefer only save/load parameters: 因此,正如文档所指出的那样 ,不建议您仅使用保存/加载参数:

...the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors. ...序列化的数据绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。

For more help, it'll be useful to show your saving code. 要获得更多帮助,显示保存代码将很有用。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM