繁体   English   中英

如何取两个网络的权重平均值?

[英]How to take the average of the weights of two networks?

假设在PyTorch我有model1model2具有相同的架构。 他们接受了相同数据的进一步培训,或者一个模型是othter的早期版本,但它在技术上与问题无关。 现在我想将model的权重设置为model1model2的权重的平均值。 我怎么能在PyTorch中这样做?

beta = 0.5 #The interpolation parameter    
params1 = model1.named_parameters()
params2 = model2.named_parameters()

dict_params2 = dict(params2)

for name1, param1 in params1:
    if name1 in dict_params2:
        dict_params2[name1].data.copy_(beta*param1.data + (1-beta)*dict_params2[name1].data)

model.load_state_dict(dict_params2)

取自pytorch论坛 您可以抓取参数,转换并加载它们,但要确保尺寸匹配。

此外,我真的很想知道你的发现与这些..

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM