簡體   English   中英

如何在 PyTorch 中獲取神經網絡的編碼器部分?

[英]How to get the encoder part of a neural network in PyTorch?

我想應用轉移學習(使用來自 UNet 或 ResNet 的預訓練編碼器的權重啟動我的自定義網絡的編碼器)。 所以問題是:給定 Pytorch 中的 UNet 或 ResNet 實例,如何在 PyTorch 中提取 ResNet 或 UNet 的編碼器部分?

這個博客展示了一種方法,但它首先要求我擁有 UNet 或 ResNet 類,這對我來說不切實際。 因為 UNet 或 ResNet 的實例是通過如下函數獲取的: net = get_resnet(depth=34) ,所以我只能獲取net = get_resnet(depth=34)或 ResNet 的實例,但無法獲取它們的類。

對於 ResNet,最后一層只是self.fc ,因此如果您實例化 ResNet 模型,您可以將self.fc重新定義為您首選的分類任務,保持模型的其余部分完好無損,包括預訓練的權重(如果適用)。

對於 UNet,這有點棘手,因為解碼器由上采樣層和輸出層組成,但同樣可以替換self.up1self.up2self.up3self.up4self.outc .

請務必在加載權重后更換圖層。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM