How do I load a local model with torch.hub.load?

I need to avoid downloading the model from the web (because of restrictions on the machine where it is installed).

This works, but it downloads the model from the internet:

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)

I have placed the .pth file and the hubconf.py file in the /tmp/ folder and changed my code to:

model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')

But to my surprise, it still downloads the model from the internet. What am I doing wrong? How can I load the model locally?

Just to give you a bit more details, I'm doing all this in a Docker container which has a read-only volume at runtime, so that's why the download of new files fails.

Thanks,

John

There are two approaches you can take to get a shippable model onto a machine without internet access.

1. Load DeepLab with the pretrained weights on a machine that has internet access, use the JIT compiler to export it as a graph, and copy the saved file to the offline machine. The script is easy to follow:

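import torch

# To export (on a machine with internet access)
H, W = 520, 520  # example input size (an assumption): trace with the size you will use at inference
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
# The model returns a dict ('out', 'aux'), which torch.jit.trace only accepts with strict=False
traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W), strict=False)
traced_graph.save('DeepLab.pth')

# To load (on the offline machine)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.jit.load('DeepLab.pth', map_location=device).eval()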

In this case, the weights and the network structure are both saved as a computational graph in one file, so you won't need hubconf.py or any other extra files to load it.
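Before shipping a traced file, it can help to confirm that the reloaded graph reproduces the original model. The sketch below does this round trip on a tiny model so it runs in seconds; it assumes nothing about DeepLab itself. TinySegNet is a hypothetical stand-in that, like torchvision's segmentation models, returns a dict with an 'out' entry, and the file name tiny.pt is arbitrary.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for DeepLabV3: returns a dict, like torchvision segmentation models."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.body = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Conv2d(8, num_classes, kernel_size=1)

    def forward(self, x):
        out = OrderedDict()
        out["out"] = self.head(torch.relu(self.body(x)))
        return out


def export_and_reload(model: nn.Module, example: torch.Tensor, path: str):
    """Trace `model`, save it as TorchScript, and load it back without the Python class."""
    model.eval()
    with torch.no_grad():
        # strict=False lets the tracer record the dict output
        traced = torch.jit.trace(model, example, strict=False)
    traced.save(path)
    return torch.jit.load(path, map_location="cpu").eval()


torch.manual_seed(0)
model = TinySegNet()
example = torch.randn(1, 3, 16, 16)
with tempfile.TemporaryDirectory() as tmp:
    loaded = export_and_reload(model, example, os.path.join(tmp, "tiny.pt"))
    with torch.no_grad():
        expected = model(example)["out"]
        actual = loaded(example)["out"]

print(torch.allclose(expected, actual))
```

The loading side only calls torch.jit.load, so on the offline machine the model class (and hubconf.py) is not needed, which is the point of this approach.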

2. Take a look at torchvision's GitHub repo.

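It contains a download URL for the DeepLabV3 weights with a ResNet-101 backbone. You can download those weights once, then build DeepLab from torchvision with the pretrained=False flag and load the weights manually. Note that torch.hub.load with 'pytorch/vision:v0.9.0' also needs the repository itself, so it must either already be in the torch hub cache or be loaded from a local copy with source='local'.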

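import torch
# pretrained_backbone=False also skips the ImageNet ResNet-101 download; aux_loss=True adds the aux_classifier stored in the official checkpoint
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False, pretrained_backbone=False, aux_loss=True)
model.load_state_dict(torch.load('downloaded weights path', map_location='cpu'))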

Keep in mind that the checkpoint might wrap the weights in a 'state_dict' (or similarly named) parent key, in which case you would use:

model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
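If you are not sure how a downloaded checkpoint is laid out, a small helper can handle both cases. This is a sketch under assumptions: unwrap_state_dict is a hypothetical name, and the parent keys it checks ('state_dict', 'model') are common conventions, not something torch defines.

```python
from collections.abc import Mapping
from typing import Iterable


def unwrap_state_dict(checkpoint: Mapping, parent_keys: Iterable[str] = ("state_dict", "model")) -> Mapping:
    """Return the parameter mapping stored in a loaded checkpoint.

    The parent key names are assumptions (common conventions, not a torch API).
    If none of them holds a mapping, the checkpoint is assumed to be the state dict itself.
    """
    for key in parent_keys:
        value = checkpoint.get(key)
        if isinstance(value, Mapping):
            return value
    return checkpoint
```

Usage would then be model.load_state_dict(unwrap_state_dict(torch.load('downloaded weights path', map_location='cpu'))), which works whether or not the weights are nested.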


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM