[英]Error in loading model in PyTorch
我有以下代碼片段
from train import predict
import random
import torch
ann=torch.load('ann.pt') #importing trained model
while True:
k=raw_input("User:")
intent,top_value,top_index = predict(str(k),ann)
print(intent)
當我運行腳本時,它會引發如下錯誤:
Traceback (most recent call last):
File "test.py", line 6, in <module>
ann=torch.load('ann.pt') #importing trained model
File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 261, in load
return _load(f, map_location, pickle_module)
File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 409, in _load
result = unpickler.load()
AttributeError: 'module' object has no attribute 'ANN'
我的腳本位於同一文件夾中的ann.pt文件。 請幫助我確定修復錯誤並加載模型。 提前致謝。
嘗試同時保存參數和模型時,pytorch會腌制參數,但僅存儲模型Class的路徑。 例如,更改樹結構或重構可能會破壞加載。 因此,正如文檔所指出的那樣 ,不建議您僅使用保存/加載參數:
...序列化的數據綁定到所使用的特定類和確切的目錄結構,因此在其他項目中使用時或經過一些嚴重的重構后,它可能以各種方式中斷。
要獲得更多幫助,顯示保存代碼將很有用。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.