
Error in loading model in PyTorch

I have the following code snippet:

from train import predict
import random
import torch


ann=torch.load('ann.pt') #importing trained model


while True:
    k = raw_input("User:")
    intent, top_value, top_index = predict(str(k), ann)
    print(intent)

When I run the script, it throws the following error:

Traceback (most recent call last):
  File "test.py", line 6, in <module>
    ann=torch.load('ann.pt') #importing trained model
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 261, in load
    return _load(f, map_location, pickle_module)
  File "/home/local/ZOHOCORP/raghav-5305/miniconda2/lib/python2.7/site-packages/torch/serialization.py", line 409, in _load
    result = unpickler.load()
AttributeError: 'module' object has no attribute 'ANN'

The ann.pt file is in the same folder as my script. Please help me identify and fix the error so that I can load the model. Thanks in advance.

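When you save the whole model object rather than only its parameters, PyTorch pickles the parameters but stores only the path to the model class (its module and name), not the class definition itself. As a result, changing the directory structure or refactoring the code can break loading. That is why, as the documentation points out, this approach is not recommended; prefer saving and loading only the parameters (the model's state_dict):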

...the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
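A minimal sketch of the parameters-only approach, assuming a hypothetical `ANN` class (the real architecture is not shown in the question, so the layer sizes below are made up) and hypothetical helper names `save_params` and `load_params`. `torch.save(model.state_dict(), path)` stores only the tensors; at load time the class is constructed in code and `load_state_dict` copies the tensors into it, so unpickling never has to locate `ANN` by its module path.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F


class ANN(nn.Module):
    # Hypothetical stand-in for the asker's ANN class: the real layer sizes
    # and architecture are not shown in the question.
    def __init__(self, n_in=10, n_hidden=8, n_out=3):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


def save_params(model, path):
    # Save only the parameter tensors (the state_dict), not the pickled class.
    torch.save(model.state_dict(), path)


def load_params(path):
    # Rebuild the model from its class definition in code, then copy the
    # saved tensors into it; the ANN class must be importable here.
    model = ANN()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()  # inference mode for any dropout / batch-norm layers
    return model


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "ann_params.pt")  # hypothetical file name
    trained = ANN()
    save_params(trained, path)
    restored = load_params(path)
    print(restored)
```

In the question's script, `ann=torch.load('ann.pt')` would then become something like `ann = load_params(...)`, with `ANN` imported from wherever it is defined. The existing `ann.pt` holds a whole-model pickle, so it would first need to be re-saved this way from the training code.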

For more specific help, please show the code you used to save the model.
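Because the saving code is not shown, the following is only a guess at the cause. If `ann.pt` was written with `torch.save(ann, 'ann.pt')`, the file records the class as a reference such as `__main__.ANN` (for a class defined in a training script that was run directly) rather than the class's code, so loading fails in any script where that name cannot be resolved. This stdlib sketch uses plain `pickle`, which `torch.save` relies on for whole objects, and a hypothetical toy `ANN` class to show what is stored and to reproduce the `AttributeError`. It pickles with protocol 2 so the class reference appears as a single `GLOBAL` opcode (protocols 4 and later use `STACK_GLOBAL` instead).

```python
import pickle
import pickletools
import sys


class ANN(object):
    # Hypothetical toy class standing in for the asker's model class.
    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]


def referenced_globals(obj):
    """Return the (module, name) pairs that a protocol-2 pickle of obj refers to."""
    data = pickle.dumps(obj, protocol=2)
    refs = []
    for opcode, arg, _pos in pickletools.genops(data):
        # GLOBAL's argument is "module name": a lookup path, not the class code.
        if opcode.name == "GLOBAL":
            module, name = arg.split(" ", 1)
            refs.append((module, name))
    return refs


def load_without_class(data):
    """Try to unpickle data while ANN is missing from its module.

    Simulates loading from a script that never defines ANN; returns the
    AttributeError message, or None if loading unexpectedly succeeds.
    """
    module = sys.modules[ANN.__module__]
    saved = module.ANN
    del module.ANN
    try:
        pickle.loads(data)
    except AttributeError as exc:
        return str(exc)
    finally:
        module.ANN = saved  # put the class back
    return None


if __name__ == "__main__":
    print(referenced_globals(ANN()))  # run as a script: [('__main__', 'ANN')]
    print(load_without_class(pickle.dumps(ANN(), protocol=2)))
```

On Python 2.7 the returned message should read like the one in the traceback; Python 3 phrases it differently. If the class was indeed pickled as `__main__.ANN`, defining or importing `ANN` in the loading script before calling `torch.load` may let the existing file load, but saving only the parameters removes the dependency on the module path altogether.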
