Runtime Error when converting PyTorch model to PyTorch JIT script

I am trying to build a simple PyTorch model and convert it to a PyTorch JIT script with the following code. (The end goal is to convert it to PyTorch Mobile.)

import torch
import torch.nn as nn

class Concat(nn.Module):
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, x):
        return torch.cat(x,1)

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        
    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        z = self.conv1(x)
        z = self.conv2(z)
        return (y, z)

net = nn.Sequential(
    Net(),
    Concat()
)
mobile_net = torch.quantization.convert(net)
scripted_net = torch.jit.script(mobile_net)

However, the code above throws the following error:

RuntimeError                              Traceback (most recent call last)
Cell In [2], line 26
     21 net = nn.Sequential(
     22     Net(),
     23     Concat()
     24 )
     25 mobile_net = torch.quantization.convert(net)
---> 26 scripted_net = torch.jit.script(mobile_net)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:1286, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1284 if isinstance(obj, torch.nn.Module):
   1285     obj = call_prepare_scriptable_func(obj)
-> 1286     return torch.jit._recursive.create_script_module(
   1287         obj, torch.jit._recursive.infer_methods_to_compile
   1288     )
   1290 if isinstance(obj, dict):
   1291     return create_script_dict(obj)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:476, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    474 if not is_tracing:
    475     AttributeTypeIsSupportedChecker().check(nn_module)
--> 476 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    535     script_module._concrete_type = concrete_type
    537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
    602 """
    603 Construct a RecursiveScriptModule that's ready for use. PyTorch
    604 code should use this to construct a RecursiveScriptModule instead
   (...)
    612     init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
    613 """
    614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
    617 # Finalize the ScriptModule: replace the nn.Module state with our
    618 # custom implementations and flip the _initializing bit.
    619 RecursiveScriptModule._finalize_scriptmodule(script_module)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
    513     scripted = orig_value
    514 else:
    515     # always reuse the provided stubs_fn to infer the methods to compile
--> 516     scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    518 cpp_module.setattr(name, scripted)
    519 script_module._modules[name] = scripted

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:
--> 542     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    543     # Create hooks after methods to ensure no name collisions between hooks and methods.
    544     # If done before, hooks can overshadow methods that aren't exported.
    545     create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:393, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    390 property_defs = [p.def_ for p in property_stubs]
    391 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 393 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
  
  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
  Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.

The original call is:
  File "C:\Users\pawan\AppData\Local\Temp\ipykernel_16484\3929675973.py", line 6
    def forward(self, x):
        return torch.cat(x,1)
               ~~~~~~~~~ <--- HERE

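I am new to PyTorch and not familiar with its internals, so please suggest a solution. If the torch.cat is done inside the forward method of the Net class, i.e. it returns torch.cat((y, z), 1) instead of (y, z), the conversion works, but I want to use a separate class for the concatenation.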

Why the error occurs

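When compiling Concat.forward, torch.jit.script finds no type annotation on the argument x, so it infers x to be a single Tensor (the default for unannotated arguments). torch.cat, however, expects a list of tensors as its first argument (every overload in the error takes Tensor[] tensors), so the call torch.cat(x, 1) matches none of the overloads and compilation fails with "Arguments for call are not valid". TorchScript types Concat.forward from its own annotations only; it does not take into account that Net actually returns the tuple (y, z).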
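A minimal way to see this inference rule in isolation, as a sketch that assumes only that torch is installed: script two standalone functions that differ only in the annotation on x. The names cat_unannotated, cat_annotated and try_script are hypothetical, introduced just for this illustration; the compile error is caught as RuntimeError because that is the exception type shown in the traceback above.

```python
from typing import Tuple

import torch


def cat_unannotated(x):
    # No annotation: TorchScript infers `x` as a single Tensor,
    # which does not match torch.cat's Tensor[] argument.
    return torch.cat(x, 1)


def cat_annotated(x: Tuple[torch.Tensor, torch.Tensor]):
    # Explicit tuple-of-tensors annotation: the torch.cat call type-checks.
    return torch.cat(x, 1)


def try_script(fn) -> bool:
    """Return True if torch.jit.script compiles `fn`, False on a compile error."""
    try:
        torch.jit.script(fn)
        return True
    except RuntimeError:
        return False


print(try_script(cat_unannotated))  # False
print(try_script(cat_annotated))    # True
```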

How to fix it

Explicitly annotate the argument x of Concat.forward as Tuple[torch.Tensor, torch.Tensor], so that torch.jit knows what type you are passing in.

from typing import Tuple

import torch
import torch.nn as nn

class Concat(nn.Module):
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        #              ^^^ torch.jit.script needs this ^^^
        return torch.cat(x,1)

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        
    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        z = self.conv1(x)
        z = self.conv2(z)
        return (y, z)

net = nn.Sequential(
    Net(),
    Concat()
)
mobile_net = torch.quantization.convert(net)
scripted_net = torch.jit.script(mobile_net)
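An alternative sketch, assuming you are free to change what Net returns: return a list and annotate Concat.forward with List[torch.Tensor], which matches torch.cat's Tensor[] argument directly and works for any number of branches. The sketch also runs the scripted model on a random input (the 1x3x32x32 input size is an assumption for illustration) and compares it with eager mode. The torch.quantization.convert step is left out because it is not needed to reproduce or fix the scripting issue.

```python
from typing import List

import torch
import torch.nn as nn


class Concat(nn.Module):
    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        # List[Tensor] is exactly what torch.cat's `Tensor[] tensors` expects,
        # and it accepts any number of branches.
        return torch.cat(x, 1)


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        y = self.conv2(self.conv1(x))
        z = self.conv2(self.conv1(x))
        return [y, z]  # a list instead of a tuple


net = nn.Sequential(Net(), Concat()).eval()
scripted_net = torch.jit.script(net)

inp = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    eager_out = net(inp)
    scripted_out = scripted_net(inp)

# Two unpadded 3x3 convs shrink 32x32 to 28x28; 32 + 32 channels after cat.
print(scripted_out.shape)  # torch.Size([1, 64, 28, 28])
print(torch.allclose(eager_out, scripted_out))
```

Either annotation works; the tuple version keeps Net unchanged, while the list version is more convenient if the number of concatenated branches varies.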
