簡體   English   中英

PyTorch中的圖像特征提取

[英]Image Feature Extraction in PyTorch

我很難理解這個代碼片段。

import torch
import torch.nn as nn
import torchvision.models as models

def ResNet152(out_features = 10):
      return getattr(models, "resnet152")(pretrained=False, num_classes = out_features)

def VGG(out_features = 10):
      return getattr(models, "vgg19")(pretrained=False, num_classes = out_features)

在此代碼段中,輸入圖像的特征由 ResNet152 和 Vgg19 model 提取。 但我有一個問題,是從這些模型的哪個部分提取特征,無論該部分是最后一個池化層還是分類層之前的層或其他什么。

請注意, getattr(models, 'resnet152')等效於models.resent152

因此,下面的代碼返回 model 本身。

getattr(models, "resnet152")(pretrained=False, num_classes = out_features)
# is same as
models.resnet152(pretrained=False, num_classes = out_features)

現在,如果您通過簡單的打印來查看 model 的結構,最后一層是全連接層,所以這就是您在這里得到的特征。

print(ResNet152())

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=10, bias=True)
)

VGG()也是如此。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM