簡體   English   中英

有什么方法可以將一個Pytorch模型的所有參數復制到另一個專門的“批量歸一化”平均值和std?

[英]Is there any way to copy all parameters of one Pytorch model to another specially Batch Normalization mean and std?

我在網上找到了許多正確的方法來將一個pytorch模型參數復制到另一個參數,但是以某種方式進行復制粘貼操作始終會錯過批標准化參數。 只要我在模型中僅使用諸如conv2d,linear,drop,max pool等模塊,一切都可以正常工作。 但是,一旦我在pytorch模型中添加了批處理規范化,下面給出的腳本就會停止工作,並且測試時的准確性有所不同:

net = model()
copy_net = model()

for param in net.module.parameters():
    copy_param.append(param.clone().detach())

count = 0
for param in copy_net.module.parameters():
    param.data =  copy_param[count]
    param.requires_grad = False
    count = count +1

有人可以給我復制批處理規范化的可能解決方案嗎?

net.load_state_dict(copy_net.state_dict())應該可以工作。

按照@dxtx,按照pytorch的哲學,狀態dict應該涵蓋“模塊”中的所有狀態,例如在批處理規范模塊中,如果我沒有記錯的話,運行均值和var應該是狀態dict的一部分。 但是實際上,如果您自己編寫了類似批處理規范的模塊,則必須重寫'state_dict'方法。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM