I was playing around with the function torch.save
and I noticed something curious. Let's say I load a model from the torchvision
repository:
model = torchvision.models.mobilenet_v2()
If I save the model this way:
torch.save(model,'model.pth')
I get a 14MB file, while if I do:
torch.save(model.state_dict(),'state_dict.pth')
the file size blows up to ~500MB. Since I didn't find any reference to this behaviour, I was wondering what causes the increase in size. Is it something related to compression? Does saving the whole state_dict
store extra stuff like uninitialized gradients?
PS: the same happens for other models like vgg16.
Hey, I deleted my last answer since I was wrong. As of PyTorch version 1.1.0a0+863818e, using:
torch.save(model,'model.pth')
And using:
torch.save(model.state_dict(),'state_dict.pth')
gave the same file size for both. Are you sure you are loading the nets correctly? Proof:
-rw-rw-r-- 1 bpinaya bpinaya 14M Aug 8 10:26 model.pth
-rw-rw-r-- 1 bpinaya bpinaya 14M Aug 8 10:27 state_dict.pth
-rw-rw-r-- 1 bpinaya bpinaya 528M Aug 8 10:29 vgg.pth
-rw-rw-r-- 1 bpinaya bpinaya 528M Aug 8 10:29 vggstate_dict.pth
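The comparison in the listing above can be reproduced without downloading anything. This is a minimal sketch, assuming a small stand-in network instead of the torchvision models; the helper name `saved_sizes` and the `toy` model are hypothetical and not from the thread.

```python
import os
import tempfile

import torch
import torch.nn as nn


def saved_sizes(model):
    """Return (whole-model bytes, state_dict bytes) as written by torch.save."""
    with tempfile.TemporaryDirectory() as tmp:
        model_path = os.path.join(tmp, "model.pth")
        state_path = os.path.join(tmp, "state_dict.pth")
        torch.save(model, model_path)               # pickles the whole module object
        torch.save(model.state_dict(), state_path)  # pickles only the tensors
        return os.path.getsize(model_path), os.path.getsize(state_path)


# Hypothetical stand-in for mobilenet_v2 / vgg16, small enough to save quickly.
toy = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
model_bytes, state_bytes = saved_sizes(toy)
print(f"model.pth: {model_bytes} bytes, state_dict.pth: {state_bytes} bytes")
```

If the two numbers differ by orders of magnitude for the same network, the likely cause is that two different networks were saved, not the save method itself.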
If you ask what is in the model:
vars(vgg16)
Out:
{'_backend': <torch.nn.backends.thnn.THNNFunctionBackend at 0x232c78759b0>,
'_parameters': OrderedDict(),
'_buffers': OrderedDict(),
'_backward_hooks': OrderedDict(),
'_forward_hooks': OrderedDict(),
'_forward_pre_hooks': OrderedDict(),
'_state_dict_hooks': OrderedDict(),
'_load_state_dict_pre_hooks': OrderedDict(),
'_modules': OrderedDict([('features', Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)),
('avgpool', AdaptiveAvgPool2d(output_size=(7, 7))),
('classifier', Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
))]),
'training': True}
You will see that it is more than just the state dict:
vgg16.state_dict()
The state dict is inside _modules, for example:
vgg16._modules['features'].state_dict()
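This is why, when you save the whole model, you save not just the state dict (which already holds the parameters and buffers) but also the rest of the module object listed above: the module structure, the hooks, the training flag...
If you only need the weights at inference time and can rebuild the model from code, you can avoid saving the rest and store just the state dict.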
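Saving only the state dict works when the architecture can be rebuilt from code at load time. The sketch below assumes a hypothetical `build_model` factory and a small stand-in network (the BatchNorm layer is there only so the state dict also contains buffers); it is an illustration, not the torchvision code.

```python
import io

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in architecture; BatchNorm1d contributes buffers.
    return nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.Linear(4, 2))


trained = build_model()
state = trained.state_dict()
# Keys cover parameters (weight, bias) and buffers (running_mean, ...).
print(list(state.keys()))

# Serialize only the state dict; an in-memory buffer stands in for a file.
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# Rebuild the architecture from code, then load the saved tensors into it.
restored = build_model()
restored.load_state_dict(torch.load(buffer))

trained.eval()
restored.eval()
x = torch.randn(3, 8)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)
```

Hooks and the training flag are not part of the state dict, so they have to be set again on the rebuilt model if they matter.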
The sizes when saving:
torch.save(model,'model.pth')
torch.save(model.state_dict(),'state_dict.pth')
should satisfy: model.pth > state_dict.pth
because the state dict is included in the model. The weights dominate both files, so the difference is usually small.
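A rough way to check which size is plausible is to add up the raw bytes of the tensors: a float32 value takes 4 bytes, so the commonly cited ~138 million parameters of VGG16 come to roughly 528 MiB, and the ~3.5 million of MobileNetV2 to roughly 14 MB. The sketch below applies the same arithmetic to a small stand-in model; `tensor_bytes`, `serialized_bytes` and `toy` are hypothetical names introduced for illustration.

```python
import io

import torch
import torch.nn as nn


def tensor_bytes(module):
    """Sum the raw size in bytes of every tensor in the module's state dict."""
    return sum(t.numel() * t.element_size() for t in module.state_dict().values())


def serialized_bytes(obj):
    """Size of what torch.save writes for obj, measured in memory."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes


# Hypothetical stand-in network; swap in any nn.Module to check its size.
toy = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

raw = tensor_bytes(toy)
state_size = serialized_bytes(toy.state_dict())
model_size = serialized_bytes(toy)
print(f"raw tensors: {raw}, state_dict: {state_size}, whole model: {model_size}")
```

Both serialized sizes should sit slightly above the raw tensor total, since each file adds only pickle metadata on top of the weights.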