簡體   English   中英

在PyTorch中加載模型時出錯

[英]Error in loading model in PyTorch

我有以下代碼片段

from train import predict
import random
import torch


ann=torch.load('ann.pt') #importing trained model


while True:
      k=raw_input("User:")
      intent,top_value,top_index = predict(str(k),ann)
      print(intent)

當我運行腳本時,它會引發如下錯誤:

Traceback (most recent call last):
  File "test.py", line 6, in <module>
    ann=torch.load('ann.pt') #importing trained model
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 261, in load
    return _load(f, map_location, pickle_module)
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 409, in _load
    result = unpickler.load()
AttributeError: 'module' object has no attribute 'ANN'

我的腳本位於同一文件夾中的ann.pt文件。 請幫助我確定修復錯誤並加載模型。 提前致謝。

嘗試同時保存參數和模型時,pytorch會腌制參數,但僅存儲模型Class的路徑。 例如,更改樹結構或重構可能會破壞加載。 因此,正如文檔所指出的那樣 ,不建議您僅使用保存/加載參數:

...序列化的數據綁定到所使用的特定類和確切的目錄結構,因此在其他項目中使用時或經過一些嚴重的重構后,它可能以各種方式中斷。

要獲得更多幫助,顯示保存代碼將很有用。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM