
PyTorch - convert ProGAN agent from pth to onnx

I trained a ProGAN agent using this PyTorch reimplementation, and I saved the agent as a .pth file. Now I need to convert the agent into the .onnx format, which I am doing using this script:

from torch.autograd import Variable

import torch.onnx
import torchvision
import torch

device = torch.device("cuda")

dummy_input = torch.randn(1, 3, 64, 64)
state_dict = torch.load("GAN_agent.pth", map_location = device)

torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

Once I run it, I get the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict' (full traceback below). As far as I understand, the problem is that converting the agent into .onnx requires more information. Am I missing something?

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-c64481d4eddd> in <module>
     10 state_dict = torch.load("GAN_agent.pth", map_location = device)
     11 
---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    146                         operator_export_type, opset_version, _retain_param_name,
    147                         do_constant_folding, example_outputs,
--> 148                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
    149 
    150 

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
     64             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
     65             example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
     67 
     68 

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
    414                                                         example_outputs, propagate,
    415                                                         _retain_param_name, do_constant_folding,
--> 416                                                         fixed_batch_size=fixed_batch_size)
    417 
    418         # TODO: Don't allocate a in-memory string for the protobuf

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
    277             model.graph, tuple(in_vars), False, propagate)
    278     else:
--> 279         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
    280         state_dict = _unique_state_dict(model)
    281         params = list(state_dict.values())

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
    226     # A basic sanity check: make sure the state_dict keys are the same
    227     # before and after running the model.  Fail fast!
--> 228     orig_state_dict_keys = _unique_state_dict(model).keys()
    229 
    230     # By default, training=False, which is good because running a model in

~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
    283     # id(v) doesn't work with it. So we always get the Parameter or Buffer
    284     # as values, and deduplicate the params using Parameters and Buffers
--> 285     state_dict = module.state_dict(keep_vars=True)
    286     filtered_dict = type(state_dict)()
    287     seen_ids = set()

AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

The file you have there is a state_dict, which is simply a mapping of layer names to tensor weights, biases, and the like (see here for a more thorough introduction).
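
You can see this for yourself with a quick check (a minimal sketch, assuming the same GAN_agent.pth file as in your script):

import torch

# torch.load on a checkpoint saved via torch.save(model.state_dict(), ...)
# returns a plain OrderedDict of name -> tensor, not an nn.Module, which is
# exactly why calling .state_dict() on it fails
state_dict = torch.load("GAN_agent.pth", map_location="cpu")

print(type(state_dict))  # <class 'collections.OrderedDict'>
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))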

What that means is that you need a model onto which those saved weights and biases can be mapped, but first things first:

1. Model preparation

Clone the repository where the model definitions are located and open the file /pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py . We need some modifications for it to work with onnx . The onnx exporter requires inputs to be passed as torch.tensor only (or a list / dict of those), while the Generator class takes int and float arguments.

A simple solution is to slightly modify the forward function (line 80 in the file, you can verify it on GitHub) to the following:

def forward(self, x, depth, alpha):
    """
    forward pass of the Generator
    :param x: input noise
    :param depth: current depth from where output is required
    :param alpha: value of alpha for fade-in effect
    :return: y => output
    """

    # THESE TWO LINES WERE ADDED
    # We will pass tensors but unpack them here to `int` and `float`
    depth = depth.item()
    alpha = alpha.item()
    # THESE TWO LINES WERE ADDED
    assert depth < self.depth, "Requested output depth cannot be produced"

    y = self.initial_block(x)

    if depth > 0:
        for block in self.layers[: depth - 1]:
            y = block(y)

        residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))
        straight = self.rgb_converters[depth](self.layers[depth - 1](y))

        out = (alpha * straight) + ((1 - alpha) * residual)

    else:
        out = self.rgb_converters[0](y)

    return out

Only unpacking via item() was added here. Every input which is not of Tensor type should be packed as one in the function definition and unpacked as soon as possible at the top of your function. It will not destroy your checkpoint, so no worries, as it's just a layer-to-weight mapping.
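
For illustration, here is the same pattern on a tiny standalone module (ToyModule and its scale argument are hypothetical, just to show the packing and unpacking):

import torch

class ToyModule(torch.nn.Module):
    def forward(self, x, scale):
        # scale arrives as a 1-element tensor so tracing works;
        # unpack it to a Python float right at the top
        scale = scale.item()
        return x * scale

module = ToyModule()
# The non-tensor argument is packed as a tensor at the call site
dummy_inputs = (torch.randn(1, 512), torch.tensor([2.0]))
torch.onnx.export(module, dummy_inputs, "toy.onnx")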

2. Model exporting

Place this script in /pro_gan_pytorch (where README.md is located as well):

import torch

from pro_gan_pytorch import PRO_GAN as pg

gen = torch.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth"))

module = gen.module.to("cpu")

# Arguments like depth and alpha may need to be changed
dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1]))
torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)

Please notice a few things:

  • We have to create the model before loading the weights, as the file contains a state_dict only.
  • torch.nn.DataParallel is needed, as that's what the model was trained with (not sure about your case, please adjust accordingly; see the sketch after this list). After loading, we can get the module itself via the module attribute.
  • Everything is cast to CPU; there is no need for a GPU here, I think. You could cast everything to GPU if you insist, though.
  • The dummy input to the generator cannot be an image (I used the files provided by the repo authors on their Google Drive); it has to be noise with 512 elements.
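
If your own checkpoint was saved without the DataParallel wrapper (or you would rather not use it), a common adjustment is to reconcile the module. prefix on the keys yourself; a sketch, assuming the prefixed checkpoint from above:

import torch

from pro_gan_pytorch import PRO_GAN as pg

# Weights saved from a DataParallel-wrapped model carry a "module." prefix
# on every key; strip it to load them into a bare Generator instead
state_dict = torch.load("GAN_GEN_SHADOW_8.pth", map_location="cpu")
stripped = {k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state_dict.items()}

gen = pg.Generator(depth=9)
gen.load_state_dict(stripped)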

Run it and your .onnx file should be there.
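
To sanity-check the export, you can load the file back with the onnxruntime package (an assumption here: that it is installed). Note that because depth and alpha were unpacked with item() during tracing, they get baked into the graph as constants, so the exported model may expose fewer inputs than export() received; the sketch below reads the surviving input names from the session instead of hard-coding them:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("GAN_GEN8.onnx")

# Pair whatever inputs survived the trace with arrays of the shapes
# and dtypes used as dummy inputs during export
candidates = [np.random.randn(1, 512).astype(np.float32),
              np.array([5], dtype=np.int64),
              np.array([0.1], dtype=np.float32)]
feed = {inp.name: arr for inp, arr in zip(session.get_inputs(), candidates)}

outputs = session.run(None, feed)
print(outputs[0].shape)  # the generated image tensor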

Oh, and as you are after a different checkpoint, you may want to follow a similar procedure, though there are no guarantees everything will work fine (it does look like it will, though).
