How to concatenate 2 pytorch models and make the first one non-trainable in PyTorch
I have two networks which I need to concatenate for my full model. However, the first model is pre-trained and I need to make it non-trainable when training the full model. How can I achieve this in PyTorch?

I am able to concatenate two models using this answer:
import torch
import torch.nn as nn

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x


class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(2, 2)  # input dim matches MyModelA's 2-dim output

    def forward(self, x):
        x = self.fc1(x)
        return x
class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2
# Create models and load the pre-trained state_dict for modelA
modelA = MyModelA()
modelB = MyModelB()
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)
Basically, I want to load the pre-trained modelA and make it non-trainable when training the ensemble model.
You can freeze all parameters of the model you don't want to train by setting requires_grad to False. Like this:
for param in model.parameters():
    param.requires_grad = False
This should work for you.
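For example, applied to the ensemble above, a minimal sketch could look like this (the Adam optimizer and learning rate are just placeholders, not part of the original answer):

modelA = MyModelA()
modelA.load_state_dict(torch.load(PATH))

# Freeze every parameter of the pre-trained model
for param in modelA.parameters():
    param.requires_grad = False

modelB = MyModelB()
model = MyEnsemble(modelA, modelB)

# Pass only the parameters that still require gradients to the optimizer
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

With requires_grad set to False, no gradients are computed for modelA's weights, so they stay fixed even though modelA still takes part in the forward pass.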
Another way is to handle this in your training loop:
modelA = MyModelA()
modelB = MyModelB()
modelB.train()  # only modelB is being trained

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)  # optimizer only sees modelB's parameters

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()
        x = modelA(samples)        # forward pass through the pre-trained modelA
        predictions = modelB(x)    # feed modelA's output to modelB
        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()
So you pass the output of modelA to modelB, but you optimize just modelB.
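If modelA should stay completely frozen, you can also run it under torch.no_grad() (and in eval mode), so that no graph is built for it at all and memory is saved. A minimal sketch of that variant (an addition on top of the loop above):

modelA.eval()  # keep the frozen model in eval mode

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()
        with torch.no_grad():      # no autograd graph is built for modelA
            x = modelA(samples)
        predictions = modelB(x)
        loss = criterionB(predictions, targets)
        loss.backward()            # gradients flow only through modelB
        optimizerB.step()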
One easy way to do that is to detach the output tensor of the model that you don't want to update; then no gradient will be backpropagated through it to the connected model. In your case, you can simply detach the x2 tensor just before concatenating it with x1 in the forward function of the MyEnsemble model, to keep the weights of modelB unchanged.

So, the new forward function should look like the following:
def forward(self, x1, x2):
    x1 = self.modelA(x1)
    x2 = self.modelB(x2)
    x = torch.cat((x1, x2.detach()), dim=1)  # detaching x2, so modelB won't be updated
    x = self.classifier(F.relu(x))
    return x
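Note that this forward assumes an ensemble that takes two separate inputs and feeds the concatenated features through a classifier head. A minimal sketch of such a module, where the nn.Linear classifier and the feature sizes (2 features from each sub-model) are assumptions, could look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB, nb_classes=2):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        # Assumed head: each sub-model outputs 2 features, concatenated to 4
        self.classifier = nn.Linear(4, nb_classes)

    def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # stop gradients from flowing into modelB
        x = self.classifier(F.relu(x))
        return x

To freeze modelA instead of modelB, detach x1 in the same way and leave x2 attached.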