
Converting a PyTorch .pth model into an ONNX model

I have a pre-trained model saved as a .pth file, and I want to convert it to a TensorFlow protobuf. I could not find a direct way to do that. I have seen that ONNX can serve as an intermediate format: export the PyTorch model to ONNX, then convert from ONNX to TensorFlow. With that approach, however, I get the following error in the first stage of the conversion.

from torch.autograd import Variable
import torch.onnx
import torchvision
import torch 

dummy_input = Variable(torch.randn(1, 3, 256, 256))
model = torch.load('./my_model.pth')
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")

It gives the following error:

File "t.py", line 9, in <module>
    torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 75, in export
    _export(model, args, f, export_params, verbose, training)
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 108, in _export
    orig_state_dict_keys = model.state_dict().keys()
AttributeError: 'dict' object has no attribute 'state_dict'

What is a possible solution?

Try changing your code to this:

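import torch
import torch.onnx

# torch.load returned a dict, so the file holds only the weights (a state_dict).
# MyModel is a placeholder: use the class that defines your model's architecture.
model = MyModel()
state_dict = torch.load('./my_model.pth')
model.load_state_dict(state_dict)
model.eval()  # export in inference mode

dummy_input = torch.randn(1, 3, 256, 256)
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")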

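The error means that torch.load returned a plain dict (most likely a state_dict holding only the weights) rather than an instance of a torch.nn.Module subclass. Once the weights are loaded into an instance of your model class, which must subclass torch.nn.Module, the export should work.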
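To check which case applies before exporting, a small helper along these lines can report what torch.load handed back. This is only a sketch: the name describe_checkpoint and the labels it returns are made up for illustration, and the "state_dict" wrapper key is a common convention in training scripts, not something the question confirms.

```python
def describe_checkpoint(obj):
    """Return a short label for the object that torch.load() returned."""
    # A full model (an nn.Module) exposes a callable state_dict(); a plain dict does not.
    if callable(getattr(obj, "state_dict", None)):
        return "module"
    if isinstance(obj, dict):
        # Assumption: some training scripts nest the weights under a "state_dict" key.
        if "state_dict" in obj:
            return "wrapped state_dict"
        return "state_dict"
    return "unknown"
```

Calling describe_checkpoint(torch.load('./my_model.pth')) on the file from the question would be expected to report one of the two dict cases, since the traceback shows that the loaded object is a dict.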

The problem is that you are loading only the weights of the model, but the export also needs the model's architecture. For example, if you are using MobileNet:

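import torch
import torchvision.models as models

# Build the architecture first, then load your own weights into it.
model = models.mobilenet_v3_large()
model.load_state_dict(torch.load('./my_model.pth'))  # path to your weights file
model.eval()  # export in inference mode
torch.onnx.export(model, torch.rand(1, 3, 640, 640), "MobilenetV3.onnx")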

For more details, see: https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html
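If load_state_dict then complains about missing or unexpected keys, the checkpoint may not be a bare state_dict. Two things are common: the weights are nested under a key such as "state_dict", and each key carries a "module." prefix because the model was saved while wrapped in torch.nn.DataParallel. The following is a minimal sketch for that situation. The helper name extract_state_dict and its parameters are hypothetical, and whether either case applies to your file is an assumption you would need to verify by printing the keys.

```python
from collections import OrderedDict


def extract_state_dict(checkpoint, wrapper_key="state_dict", prefix="module."):
    """Return a flat state_dict with any wrapper key and key prefix removed."""
    # Assumption: if wrapper_key is present, the real weights live underneath it.
    weights = checkpoint.get(wrapper_key, checkpoint)
    cleaned = OrderedDict()
    for name, tensor in weights.items():
        # nn.DataParallel prepends "module." to every parameter name.
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = tensor
    return cleaned
```

The result can be passed to model.load_state_dict(...) in place of the raw output of torch.load.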
