簡體   English   中英

有什么方法可以將PyTorch中可用的預訓練模型下載到特定路徑嗎?

[英]Is there any way I can download the pre-trained models available in PyTorch to a specific path?

我指的是可以在這里找到的模型: https://pytorch.org/docs/stable/torchvision/models.html#torchvision-models

因為, @dennlinger在他的回答中提到: torch.utils.model_zoo ,在您加載預訓練模型時在內部調用。

更具體地說,每次加載預訓練模型時都會調用方法: torch.utils.model_zoo.load_url() 相同的文檔提到:

model_dir的默認值為$TORCH_HOME/models ,其中$TORCH_HOME默認為~/.torch

可以使用$TORCH_HOME環境變量覆蓋默認目錄。

這可以按如下方式完成:

import torch 
import torchvision
import os

# Suppose you are trying to load pre-trained resnet model in directory- models\resnet

os.environ['TORCH_HOME'] = 'models\\resnet' #setting the environment variable
resnet = torchvision.models.resnet18(pretrained=True)

我通過在 PyTorch 的 GitHub 存儲庫中提出問題來遇到上述解決方案: https : //github.com/pytorch/vision/issues/616

這導致了文檔的改進,即上面提到的解決方案。

是的,您可以簡單地復制網址並使用wget將其下載到所需的路徑。 這是一個插圖:

對於AlexNet

$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth

對於Google Inception (v3)

$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth

對於SqueezeNet

$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth

對於MobileNetV2

$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth

對於DenseNet201

$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth

對於MNASNet1_0

$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth

對於ShuffleNetv2_x1.0

$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

如果您想在 Python 中執行此操作,請使用以下內容:

In [11]: from six.moves import urllib

# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"

# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)

PS:下載地址可以在torchvision.models的python模塊中找到

有一個可用的腳本 output 整個 package 的 URL 列表。

pytorch/vision package 中執行以下命令:

python scripts/collect_model_urls.py .

# ...
# https://download.pytorch.org/models/swin_v2_b-781e5279.pth
# https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth
# https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth
# https://download.pytorch.org/models/vgg11-8a719046.pth
# https://download.pytorch.org/models/vgg11_bn-6002323d.pth
# ...

TL;DR:不,不能直接使用,但您可以輕松適應。

我認為您想要做的是查看torch.utils.model_zoo ,它在您加載預訓練模型時在內部調用:

如果我們查看預訓練模型的代碼,例如這里的AlexNet,我們可以看到它只是調用了前面提到的model_zoo函數,但沒有保存位置。 您可以修改 PyTorch 源以指定這一點(這實際上是一個很好的補充 IMO,因此可能為此打開一個拉取請求),或者只是根據自己的喜好采用第二個鏈接中的代碼(並將其保存到以不同的名稱自定義位置),然后在那里手動插入相關位置。

如果您想定期更新 PyTorch,我強烈推薦第二種方法,因為它不涉及直接更改 PyTorch 的代碼庫,並且可能會在更新期間引發錯誤。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM