繁体   English   中英

在PyTorch中加载模型时出错

[英]Error in loading model in PyTorch

我有以下代码片段

from train import predict
import random
import torch


ann=torch.load('ann.pt') #importing trained model


while True:
      k=raw_input("User:")
      intent,top_value,top_index = predict(str(k),ann)
      print(intent)

当我运行脚本时,它会引发如下错误:

Traceback (most recent call last):
  File "test.py", line 6, in <module>
    ann=torch.load('ann.pt') #importing trained model
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 261, in load
    return _load(f, map_location, pickle_module)
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 409, in _load
    result = unpickler.load()
AttributeError: 'module' object has no attribute 'ANN'

我的脚本位于同一文件夹中的ann.pt文件。 请帮助我确定修复错误并加载模型。 提前致谢。

尝试同时保存参数和模型时,pytorch会腌制参数,但仅存储模型Class的路径。 例如,更改树结构或重构可能会破坏加载。 因此,正如文档所指出的那样 ,不建议您仅使用保存/加载参数:

...序列化的数据绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。

要获得更多帮助,显示保存代码将很有用。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM