简体   繁体   English

在 pytorch model 中获取权重和偏差并将它们复制到另一个 model 中的类似层的正确方法是什么?

[英]What is the correct way to fetch weights and biases in a pytorch model and copy those to a similar layer in another model?

I am trying to copy weights from a pretrained model layer by layer into another model of exactly similar structure.我正在尝试将权重从预训练的 model 逐层复制到另一个结构完全相同的 model 中。 The original model gives an accuracy of 94% on a binary image classification problem but the target model is unable to predict, results in predicting only one class for entire test set.原始 model 在二值图像分类问题上的准确率为 94%,但目标 model 无法预测,导致整个测试集仅预测一个 class。

For example, I used this piece of code to manually copy weights from the stem of the pretrained model to the stem of the target:例如,我使用这段代码手动将权重从预训练的 model 的词干复制到目标的词干:

modelmix.stem[0].weight = modelSep.stem[0].weight
modelmix.stem[1].weight = modelSep.stem[1].weight
modelmix.stem[1].bias = modelSep.stem[1].bias

where modelmix is the target and modelSep is the pretrained model.其中 modelmix 是目标,modelSep 是预训练的 model。 Used a similar snippet for all the other layers.对所有其他层使用了类似的片段。 The target model is not working even though I can see the weights are similar for all layers.即使我可以看到所有层的权重都相似,目标 model 也不起作用。 I am Using pytorch 1.1.我正在使用 pytorch 1.1。 Thank you谢谢

You can create another model that the parameters names are equal for example:您可以创建另一个参数名称相同的 model,例如:

import torch.nn as nn

model1 = nn.Sequential()
model1.add_module('layer1', nn.Linear(10, 20))
model1.add_module('layer2', nn.Linear(20, 10))

model2 = nn.Sequential()
model2.add_module('layer1', nn.Linear(10, 20))
model2.add_module('layer2', nn.Linear(20, 10))
model2.add_module('layer3', nn.Linear(10, 5))

Then you can load the model1 state_dict to model2, and viceversa with the kwargs strict=False .然后您可以将 model1 state_dict 加载到 model2,反之亦然,使用 kwargs strict=False

model2.load_state_dict(model1.state_dict(), strict=False)

If you want something more custom, you should go as you mention.如果您想要更定制的东西,您应该提到 go。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM