簡體   English   中英

如何取兩個網絡的權重平均值?

[英]How to take the average of the weights of two networks?

假設在PyTorch我有model1model2具有相同的架構。 他們接受了相同數據的進一步培訓,或者一個模型是othter的早期版本,但它在技術上與問題無關。 現在我想將model的權重設置為model1model2的權重的平均值。 我怎么能在PyTorch中這樣做?

beta = 0.5 #The interpolation parameter    
params1 = model1.named_parameters()
params2 = model2.named_parameters()

dict_params2 = dict(params2)

for name1, param1 in params1:
    if name1 in dict_params2:
        dict_params2[name1].data.copy_(beta*param1.data + (1-beta)*dict_params2[name1].data)

model.load_state_dict(dict_params2)

取自pytorch論壇 您可以抓取參數,轉換並加載它們,但要確保尺寸匹配。

此外,我真的很想知道你的發現與這些..

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM