
Converting a PyTorch .ckpt (GAN) model to ONNX

I have a pre-trained GAN model saved as a .ckpt file. I want to convert it to an ONNX model, but I have not found a way to do that.

I trained 10 categories in RaFD mode with https://github.com/yunjey/StarGAN and obtained a pre-trained model, which I then tried to convert to ONNX.

When I did, I got the error below.

I cannot work out what is causing this error. Please tell me how to fix it.

Here is my code.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from torch.autograd import Variable
from collections import OrderedDict


class ResidualBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

    def forward(self, x):
        return x + self.main(x)


class Generator(nn.Module):

    def __init__(self, conv_dim=64, c_dim=10, repeat_num=6):
        super(Generator, self).__init__()

        layers = []
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7,
                                stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
        layers.append(nn.ReLU(inplace=True))

        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2,
                                    kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim //
                                             2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7,
                                stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        return self.main(x)


model = Generator().cuda()
state_dict = torch.load('../models/300000-G.ckpt')
model.load_state_dict(state_dict, strict=False)
dummy_input = Variable(torch.randn(32, 3, 256, 256)).cuda()


torch.onnx.export(model, , 'model.onnx', verbose=False)

Here is the error:

Traceback (most recent call last):
  File "test.py", line 85, in <module>
    torch.onnx.export(model, x, 'model.onnx', verbose=False)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/onnx/utils.py", line 84, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/onnx/utils.py", line 134, in _export
    trace, torch_out = torch.jit.get_trace_graph(model, args)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/jit/__init__.py", line 255, in get_trace_graph
    return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/jit/__init__.py", line 288, in forward
    out = self.inner(*trace_inputs)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "~/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 479, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'c'

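You have not passed dummy_input to the export function. In addition, the TypeError in the traceback shows that Generator.forward expects two arguments (x and c), so a label tensor must be supplied as well, and both inputs must be passed as a tuple. Assuming a hypothetical dummy_c of shape (32, 10) to match c_dim=10, it should be:

dummy_c = torch.randn(32, 10).cuda()
torch.onnx.export(model, (dummy_input, dummy_c), 'model.onnx', verbose=False)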
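
As a minimal sketch of exporting a model whose forward takes two inputs, the snippet below uses a small stand-in module with the same forward(x, c) signature as the Generator in the question. The names TinyGenerator, export_two_input, dummy_x and dummy_c are hypothetical, the tensor sizes are chosen only to keep the example fast, and the one-hot label is an assumption about how c is encoded. It runs on CPU and first reproduces the TypeError by calling the model with a single argument.

```python
import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in with the same forward(x, c) signature as Generator."""

    def __init__(self, c_dim=10):
        super().__init__()
        self.conv = nn.Conv2d(3 + c_dim, 3, kernel_size=3, padding=1, bias=False)

    def forward(self, x, c):
        # Broadcast the label vector over the spatial dimensions, as in the question.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        return self.conv(torch.cat([x, c], dim=1))


def export_two_input(model, x, c, path):
    """Export a model whose forward takes (x, c); the inputs go in as one tuple."""
    model.eval()
    torch.onnx.export(
        model,
        (x, c),  # one tuple entry per positional argument of forward
        path,
        input_names=["x", "c"],
        output_names=["out"],
        verbose=False,
    )


model = TinyGenerator().eval()
dummy_x = torch.randn(1, 3, 16, 16)
dummy_c = torch.zeros(1, 10)
dummy_c[0, 0] = 1.0  # assumed one-hot label for the first category

# Calling with only x reproduces the TypeError from the traceback.
try:
    model(dummy_x)
    missing_arg_error = None
except TypeError as err:
    missing_arg_error = str(err)

# With both inputs the forward pass succeeds.
with torch.no_grad():
    out = model(dummy_x, dummy_c)

# To write the ONNX file, call:
# export_two_input(model, dummy_x, dummy_c, "model.onnx")
```

The export call is left commented out so that the snippet does not write a file when run. For the real model, the same pattern would apply with the Generator instance and inputs placed on the same device as the model.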

