简体   繁体   English

修剪模型后删除 Pytorch 中的 weight_orig

[英]Remove the weight_orig in Pytorch after Pruning a model

After a model is pruned in Pytorch, the saved model contains both the pruned weights and weight_orig.在 Pytorch 中修剪模型后,保存的模型包含修剪后的权重和 weight_orig。 This causes the pruned model size to be greater than the unpruned model.这会导致修剪后的模型大小大于未修剪的模型。 Is there a way to remove the weight_orig and reduce the pruned model size?有没有办法删除 weight_orig 并减少修剪后的模型大小?

As explained in the offcial documentation , you can use torch.nn.utils.prune.remove() for this very purpose.正如官方文档中所述,您可以为此目的使用torch.nn.utils.prune.remove()
remove() removes the re-parametrization in terms of weight_orig and weight_mask , and removes the forward_pre_hook . remove()根据weight_origweight_mask删除重新参数化,并删除forward_pre_hook You'd use it like this:你会像这样使用它:

for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module,'weight')
    # etc...

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM