
How to get the input size for an operator in a PyTorch script model?

I use this code to convert the model to a script model:

scripted_model = torch.jit.trace(detector.model, images).eval()

Then I print the scripted_model. A part of the output is as follows:

 (base): DLA(
    original_name=DLA
    (base_layer): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level0): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level1): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level2): Tree(
      original_name=Tree
      (tree1): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (tree2): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (root): Root(
        original_name=Root
        (conv): Conv2d(original_name=Conv2d)
        (bn): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
      )
      (downsample): MaxPool2d(original_name=MaxPool2d)
      (project): Sequential(
        original_name=Sequential
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
    ) 
...

I just want to get the input size for an operator, for example the inputs of the operator (0): Conv2d(original_name=Conv2d). I printed the graph of this script model; the output is as follows:

  %4770 : __torch__.torch.nn.modules.module.___torch_mangle_11.Module = prim::GetAttr[name="wh"](%self.1)
  %4762 : __torch__.torch.nn.modules.module.___torch_mangle_15.Module = prim::GetAttr[name="tracking"](%self.1)
  %4754 : __torch__.torch.nn.modules.module.___torch_mangle_23.Module = prim::GetAttr[name="rot"](%self.1)
  %4746 : __torch__.torch.nn.modules.module.___torch_mangle_7.Module = prim::GetAttr[name="reg"](%self.1)
  %4738 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="hm"](%self.1)
  %4730 : __torch__.torch.nn.modules.module.___torch_mangle_27.Module = prim::GetAttr[name="dim"](%self.1)
  %4722 : __torch__.torch.nn.modules.module.___torch_mangle_19.Module = prim::GetAttr[name="dep"](%self.1)
  %4714 : __torch__.torch.nn.modules.module.___torch_mangle_31.Module = prim::GetAttr[name="amodel_offset"](%self.1)
  %4706 : __torch__.torch.nn.modules.module.___torch_mangle_289.Module = prim::GetAttr[name="ida_up"](%self.1)
  %4645 : __torch__.torch.nn.modules.module.___torch_mangle_262.Module = prim::GetAttr[name="dla_up"](%self.1)
  %4461 : __torch__.torch.nn.modules.module.___torch_mangle_180.Module = prim::GetAttr[name="base"](%self.1)
  %5100 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4461, %input.1)
  %5082 : Tensor, %5083 : Tensor, %5084 : Tensor, %5085 : Tensor, %5086 : Tensor, %5087 : Tensor, %5088 : Tensor, %5089 : Tensor = prim::TupleUnpack(%5100)
  %5101 : (Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4645, %5082, %5083, %5084, %5085, %5086, %5087, %5088, %5089)
  %5097 : Tensor, %5098 : Tensor, %5099 : Tensor = prim::TupleUnpack(%5101)
  %3158 : None = prim::Constant()
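
The graph nodes can also be iterated from Python; a rough sketch (whether the printed input types carry concrete shapes depends on the PyTorch version):

# walk the top-level nodes of the traced graph and print each operator
# together with the types of its inputs; submodule graphs (for example
# scripted_model.base.graph) can be walked the same way
for node in scripted_model.graph.nodes():
    print(node.kind())
    for inp in node.inputs():
        print("   ", inp.type())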

I can even find the operator name. How can I get the input size for a specific operator in the script model?

One solution is to use summary from torchinfo: the output shape of each layer is the input shape of the next one, and so on:

!pip install torchinfo

from torchinfo import summary
summary(model, input_size=(batch_size, 3, 224, 224)) # input size to your NN

#output 

    ===============================================================================================
    Layer (type:depth-idx)                        Output Shape              Param #
    ===============================================================================================
    ResNet50                                      --                        --
    ├─ResNet: 1-1                                 [64, 10]                  --
    │    └─Conv2d: 2-1                            [64, 64, 112, 112]        9,408
    │    └─BatchNorm2d: 2-2                       [64, 64, 112, 112]        128
    │    └─ReLU: 2-3                              [64, 64, 112, 112]        --
    │    └─MaxPool2d: 2-4                         [64, 64, 56, 56]          --
    │    └─Sequential: 2-5                        [64, 64, 56, 56]          --
    │    │    └─BasicBlock: 3-1                   [64, 64, 56, 56]          73,984
    │    │    └─BasicBlock: 3-2                   [64, 64, 56, 56]          73,984
    │    └─Sequential: 2-6                        [64, 128, 28, 28]         --
    │    │    └─BasicBlock: 3-3                   [64, 128, 28, 28]         230,144
    │    │    └─BasicBlock: 3-4                   [64, 128, 28, 28]         295,424
    │    └─Sequential: 2-7                        [64, 256, 14, 14]         --
    │    │    └─BasicBlock: 3-5                   [64, 256, 14, 14]         919,040
    │    │    └─BasicBlock: 3-6                   [64, 256, 14, 14]         1,180,672
    │    └─Sequential: 2-8                        [64, 512, 7, 7]           --
    │    │    └─BasicBlock: 3-7                   [64, 512, 7, 7]           3,673,088
    │    │    └─BasicBlock: 3-8                   [64, 512, 7, 7]           4,720,640
    │    └─AdaptiveAvgPool2d: 2-9                 [64, 512, 1, 1]           --
    │    └─Linear: 2-10                           [64, 10]                  5,130
    ===============================================================================================
    Total params: 11,181,642
    Trainable params: 11,181,642
    Non-trainable params: 0
    Total mult-adds (G): 116.07
    ===============================================================================================
    Input size (MB): 38.54
    Forward/backward pass size (MB): 2543.33
    Params size (MB): 44.73
    Estimated Total Size (MB): 2626.59
    ===============================================================================================
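
If you only need the input size of one specific submodule, another option is to register a forward hook on the eager model and read the shapes during a single forward pass; a minimal sketch (the submodule path base.base_layer[0] is just an example taken from the printed structure in the question):

def report_input_shape(module, inputs, output):
    # inputs is the tuple of tensors passed to this module's forward()
    print(module.__class__.__name__, [tuple(t.shape) for t in inputs])

# attach the hook to the operator of interest, run one forward pass, then detach
handle = detector.model.base.base_layer[0].register_forward_hook(report_input_shape)
detector.model(images)
handle.remove()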
