[英]How can I add new layers on pre-trained model with PyTorch? (Keras example given.)
[英]Is there any way I can download the pre-trained models available in PyTorch to a specific path?
因為, @dennlinger在他的回答中提到: torch.utils.model_zoo
,在您加載預訓練模型時在內部調用。
更具體地說,每次加載預訓練模型時都會調用方法: torch.utils.model_zoo.load_url()
。 相同的文檔提到:
model_dir
的默認值為$TORCH_HOME/models
,其中$TORCH_HOME
默認為~/.torch
。可以使用
$TORCH_HOME
環境變量覆蓋默認目錄。
這可以按如下方式完成:
import torch
import torchvision
import os
# Suppose you are trying to load pre-trained resnet model in directory- models\resnet
os.environ['TORCH_HOME'] = 'models\\resnet' #setting the environment variable
resnet = torchvision.models.resnet18(pretrained=True)
我通過在 PyTorch 的 GitHub 存儲庫中提出問題來遇到上述解決方案: https : //github.com/pytorch/vision/issues/616
這導致了文檔的改進,即上面提到的解決方案。
是的,您可以簡單地復制網址並使用wget
將其下載到所需的路徑。 這是一個插圖:
對於AlexNet :
$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth
對於Google Inception (v3) :
$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
對於SqueezeNet :
$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth
對於MobileNetV2 :
$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
對於DenseNet201 :
$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth
對於MNASNet1_0 :
$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth
對於ShuffleNetv2_x1.0 :
$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
如果您想在 Python 中執行此操作,請使用以下內容:
In [11]: from six.moves import urllib
# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)
PS:下載地址可以在torchvision.models的python模塊中找到
有一個可用的腳本 output 整個 package 的 URL 列表。
在pytorch/vision
package 中執行以下命令:
python scripts/collect_model_urls.py .
# ...
# https://download.pytorch.org/models/swin_v2_b-781e5279.pth
# https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth
# https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth
# https://download.pytorch.org/models/vgg11-8a719046.pth
# https://download.pytorch.org/models/vgg11_bn-6002323d.pth
# ...
TL;DR:不,不能直接使用,但您可以輕松適應。
我認為您想要做的是查看torch.utils.model_zoo
,它在您加載預訓練模型時在內部調用:
如果我們查看預訓練模型的代碼,例如這里的AlexNet,我們可以看到它只是調用了前面提到的model_zoo
函數,但沒有保存位置。 您可以修改 PyTorch 源以指定這一點(這實際上是一個很好的補充 IMO,因此可能為此打開一個拉取請求),或者只是根據自己的喜好采用第二個鏈接中的代碼(並將其保存到以不同的名稱自定義位置),然后在那里手動插入相關位置。
如果您想定期更新 PyTorch,我強烈推薦第二種方法,因為它不涉及直接更改 PyTorch 的代碼庫,並且可能會在更新期間引發錯誤。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.