簡體   English   中英

我要更改哪些參數才能從頭開始訓練 pytorch model?

[英]What parameters do I change to train a pytorch model from scratch?

I followed this tutorial to train a pytorch model for instance segmentation: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

我不想在完全不同的數據和類上訓練 model,與 COCO 完全無關。 我需要進行哪些更改才能重新訓練 model。 根據我的閱讀,我猜除了有正確數量的課程外,我只需要訓練這條線:

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)

但我注意到還有另一個參數: pretrained_backbone=True, trainable_backbone_layers=None他們也應該改變嗎?

function 簽名是

torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs)

設置pretrained=False將告訴 PyTorch 不要下載在 COCO train2017 上預訓練的 model。 您想要它,因為您對培訓感興趣。

通常,如果您想在不同的數據集上進行訓練,這就足夠了。

當您設置pretrained=False時,PyTorch 將在 ImageNet 上下載預訓練的 ResNet50。 默認情況下,它會凍結名為conv1layer1的前兩個塊。 這就是 Faster R-CNN 論文中的做法,該論文凍結了預訓練主干的初始層。

(只需打印 model 即可檢查其結構)。

layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]

現在,如果您甚至不希望前兩層凍結,您可以設置 trainable_backbone_layers trainable_backbone_layers=5 (當您設置pretrained_backbone=False時自動完成),這將從頭開始訓練整個 resnet 主干。

檢查PR#2160

來自maskrcnn_resnet50_fpn 文檔

  • pretrained (bool) – 如果為真,則返回在 COCO train2017 上預訓練的 model
  • pretrained_backbone (bool) - 如果為真,返回 model,主干在 Imagenet 上預訓練
  • trainable_backbone_layers (int) – 從最終塊開始的可訓練(未凍結)resnet 層數。 有效值介於 0 和 5 之間,其中 5 表示所有主干層都是可訓練的。

因此,使用以下方法從頭開始訓練:

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, trainable_backbone_layers=5, num_classes=your_num_classes)

或者:

model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, num_classes=your_num_classes)

因為在maskrcnn_resnet50_fpn的源代碼中:

if not (pretrained or pretrained_backbone):
    trainable_backbone_layers = 5

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM