Runtime Error when converting PyTorch model to PyTorch jit script
I am trying to build a simple PyTorch model and convert it to a PyTorch jit script with the following code. (The end goal is to convert it for PyTorch Mobile.)
import torch
import torch.nn as nn

class Concat(nn.Module):
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, x):
        return torch.cat(x,1)

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        z = self.conv1(x)
        z = self.conv2(z)
        return (y, z)

net = nn.Sequential(
    Net(),
    Concat()
)

mobile_net = torch.quantization.convert(net)
scripted_net = torch.jit.script(mobile_net)
But the code above throws the following error:
RuntimeError Traceback (most recent call last)
Cell In [2], line 26
21 net = nn.Sequential(
22 Net(),
23 Concat()
24 )
25 mobile_net = torch.quantization.convert(net)
---> 26 scripted_net = torch.jit.script(mobile_net)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:1286, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1284 if isinstance(obj, torch.nn.Module):
1285 obj = call_prepare_scriptable_func(obj)
-> 1286 return torch.jit._recursive.create_script_module(
1287 obj, torch.jit._recursive.infer_methods_to_compile
1288 )
1290 if isinstance(obj, dict):
1291 return create_script_dict(obj)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:476, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
474 if not is_tracing:
475 AttributeTypeIsSupportedChecker().check(nn_module)
--> 476 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
535 script_module._concrete_type = concrete_type
537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
540 # Compile methods if necessary
541 if concrete_type not in concrete_type_store.methods_compiled:
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
602 """
603 Construct a RecursiveScriptModule that's ready for use. PyTorch
604 code should use this to construct a RecursiveScriptModule instead
(...)
612 init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
613 """
614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
617 # Finalize the ScriptModule: replace the nn.Module state with our
618 # custom implementations and flip the _initializing bit.
619 RecursiveScriptModule._finalize_scriptmodule(script_module)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
513 scripted = orig_value
514 else:
515 # always reuse the provided stubs_fn to infer the methods to compile
--> 516 scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
518 cpp_module.setattr(name, scripted)
519 script_module._modules[name] = scripted
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
540 # Compile methods if necessary
541 if concrete_type not in concrete_type_store.methods_compiled:
--> 542 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
543 # Create hooks after methods to ensure no name collisions between hooks and methods.
544 # If done before, hooks can overshadow methods that aren't exported.
545 create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
File ~\anaconda3\envs\yolov5pytorch\lib\site-packages\torch\jit\_recursive.py:393, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
390 property_defs = [p.def_ for p in property_stubs]
391 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 393 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'tensors' to be of type 'Tensor' because it was not annotated with an explicit type.
The original call is:
File "C:\Users\pawan\AppData\Local\Temp\ipykernel_16484\3929675973.py", line 6
    def forward(self, x):
        return torch.cat(x,1)
               ~~~~~~~~~ <--- HERE
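I'm new to PyTorch and not familiar with its internals, so I'd appreciate a solution. If the torch.cat is moved into the forward method of the Net class, i.e. it returns torch.cat((y, z),1) instead of (y, z), then it works, but I want to use a separate class for the concatenation.
Why the error happens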
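When compiling Concat.forward, torch.jit assumes that the argument x is a Tensor, because TorchScript treats every argument without a type annotation as a Tensor. Every overload of torch.cat, however, expects a list of tensors (Tensor[]) as its first argument, so torch.jit concludes "Arguments for call are not valid" (a single Tensor is not a list of tensors). The error is raised while Concat is compiled on its own, so the tuple (y, z) that Net.forward actually returns is never consulted: torch.jit.script type-checks each forward method from its annotations, without running the model.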
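To see the failure in isolation, here is a minimal sketch that scripts two free functions instead of modules (the names cat_unannotated and cat_annotated are hypothetical, not from the question). The unannotated version should fail to script for the same reason as Concat.forward, while annotating the argument as List[torch.Tensor] lets it compile, because that matches aten::cat(Tensor[] tensors, int dim).

```python
from typing import List

import torch


def cat_unannotated(x):
    # No annotation: TorchScript types `x` as Tensor, and no aten::cat
    # overload accepts a single Tensor as its first argument.
    return torch.cat(x, 1)


def cat_annotated(x: List[torch.Tensor]):
    # Explicit List[Tensor] annotation matches aten::cat(Tensor[] tensors, int dim).
    return torch.cat(x, 1)


try:
    torch.jit.script(cat_unannotated)
    unannotated_failed = False
except (RuntimeError, torch.jit.Error) as err:
    unannotated_failed = True
    print("scripting cat_unannotated failed:", type(err).__name__)

scripted_cat = torch.jit.script(cat_annotated)
out = scripted_cat([torch.ones(1, 2, 3, 3), torch.ones(1, 2, 3, 3)])
print(out.shape)  # torch.Size([1, 4, 3, 3])
```

A List annotation accepts any number of tensors; the Tuple[torch.Tensor, torch.Tensor] annotation used in the fix fixes the count at two, which matches what Net.forward returns.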
How to fix it
Annotate the argument x of Concat.forward explicitly as Tuple[torch.Tensor, torch.Tensor], so that torch.jit knows what you intend:
from typing import Tuple

import torch
import torch.nn as nn

class Concat(nn.Module):
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        # ^^^ torch.jit.script needs this ^^^
        return torch.cat(x,1)

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        z = self.conv1(x)
        z = self.conv2(z)
        return (y, z)

net = nn.Sequential(
    Net(),
    Concat()
)

mobile_net = torch.quantization.convert(net)
scripted_net = torch.jit.script(mobile_net)
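To check that the fixed model scripts and behaves like the eager version, a minimal sketch under these assumptions: a random 1x3x32x32 input (a hypothetical size; each unpadded 3x3 conv shrinks the spatial size by 2, so 32 becomes 28), the model in eval mode, and the torch.quantization.convert step left out to keep the sketch minimal.

```python
from typing import Tuple

import torch
import torch.nn as nn


class Concat(nn.Module):
    # The Tuple annotation is what lets torch.jit.script compile torch.cat here.
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        return torch.cat(x, 1)


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        y = self.conv2(self.conv1(x))
        z = self.conv2(self.conv1(x))
        return (y, z)


net = nn.Sequential(Net(), Concat()).eval()
scripted_net = torch.jit.script(net)

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    eager_out = net(x)
    scripted_out = scripted_net(x)

# Two 32-channel outputs concatenated on dim 1 give 64 channels.
print(scripted_out.shape)  # torch.Size([1, 64, 28, 28])
print(torch.allclose(eager_out, scripted_out))  # True
```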