[英]Error in loading model in PyTorch
我有以下代码片段
from train import predict
import random
import torch
ann=torch.load('ann.pt') #importing trained model
while True:
k=raw_input("User:")
intent,top_value,top_index = predict(str(k),ann)
print(intent)
当我运行脚本时,它会引发如下错误:
Traceback (most recent call last):
File "test.py", line 6, in <module>
ann=torch.load('ann.pt') #importing trained model
File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 261, in load
return _load(f, map_location, pickle_module)
File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 409, in _load
result = unpickler.load()
AttributeError: 'module' object has no attribute 'ANN'
我的脚本位于同一文件夹中的ann.pt文件。 请帮助我确定修复错误并加载模型。 提前致谢。
尝试同时保存参数和模型时,pytorch会腌制参数,但仅存储模型Class的路径。 例如,更改树结构或重构可能会破坏加载。 因此,正如文档所指出的那样 ,不建议您仅使用保存/加载参数:
...序列化的数据绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。
要获得更多帮助,显示保存代码将很有用。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.