![](/img/trans.png)
[英]Is there a tensorflow way to extract/save mean and std used for normalization?
[英]Is there any way to copy all parameters of one Pytorch model to another specially Batch Normalization mean and std?
我在網上找到了許多正確的方法來將一個pytorch模型參數復制到另一個參數,但是以某種方式進行復制粘貼操作始終會錯過批標准化參數。 只要我在模型中僅使用諸如conv2d,linear,drop,max pool等模塊,一切都可以正常工作。 但是,一旦我在pytorch模型中添加了批處理規范化,下面給出的腳本就會停止工作,並且測試時的准確性有所不同:
net = model()
copy_net = model()
for param in net.module.parameters():
copy_param.append(param.clone().detach())
count = 0
for param in copy_net.module.parameters():
param.data = copy_param[count]
param.requires_grad = False
count = count +1
有人可以給我復制批處理規范化的可能解決方案嗎?
net.load_state_dict(copy_net.state_dict())
應該可以工作。
按照@dxtx,按照pytorch的哲學,狀態dict應該涵蓋“模塊”中的所有狀態,例如在批處理規范模塊中,如果我沒有記錯的話,運行均值和var應該是狀態dict的一部分。 但是實際上,如果您自己編寫了類似批處理規范的模塊,則必須重寫'state_dict'方法。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.