
How to concatenate 2 pytorch models and make the first one non-trainable in PyTorch

I have two networks which I need to concatenate for my full model. However, the first model is pre-trained, and I need to make it non-trainable when training the full model. How can I achieve this in PyTorch?

I am able to concatenate the two models using this answer:

import torch
import torch.nn as nn


class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(2, 2)  # input size matches modelA's 2-dim output
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create models and load the pre-trained state_dict into modelA
modelA = MyModelA()
modelB = MyModelB()
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)

Basically, I want to load the pre-trained modelA and make it non-trainable when training the ensemble model.

You can freeze all parameters of the model you don't want to train by setting requires_grad to False, like this:

for param in modelA.parameters():  # the pre-trained model to freeze
    param.requires_grad = False

This should work for you.
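Putting that together with the question's code, a minimal sketch (reusing MyModelA, MyModelB, and MyEnsemble from above) freezes modelA and then hands the optimizer only the parameters that still require gradients:

modelA = MyModelA()
modelA.load_state_dict(torch.load(PATH))

# Freeze every parameter of the pre-trained model
for param in modelA.parameters():
    param.requires_grad = False

model = MyEnsemble(modelA, MyModelB())

# Only parameters with requires_grad=True reach the optimizer,
# i.e. modelB's parameters
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.001
)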

Another way is to handle this in your training loop:

modelA = MyModelA()
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)

modelA.eval()   # modelA is frozen, so keep it in eval mode
modelB.train()

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()

        with torch.no_grad():        # no gradients flow back into modelA
            x = modelA(samples)
        predictions = modelB(x)      # feed modelA's output into modelB

        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()            # only modelB's parameters are updated

So you pass the output of modelA to modelB, but you optimize just modelB.
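As a quick sanity check (my addition, not part of the original answer), you can snapshot modelA's weights before training and verify afterwards that they are unchanged:

before = {k: v.clone() for k, v in modelA.state_dict().items()}
# ... run the training loop above ...
after = modelA.state_dict()
assert all(torch.equal(before[k], after[k]) for k in before)  # modelA untouched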

One easy way to do that is to detach the output tensor of the model that you don't want to update; gradients will then not backpropagate through it into that model. In your case, you can simply detach the x2 tensor just before concatenating it with x1 in the forward function of the MyEnsemble model, keeping the weights of modelB unchanged. (To freeze modelA instead, as the question asks, detach x1 in the same way.)

So the new forward function should look like the following:

def forward(self, x1, x2):
    x1 = self.modelA(x1)
    x2 = self.modelB(x2)
    x = torch.cat((x1, x2.detach()), dim=1)  # detach x2 so modelB won't be updated
    x = self.classifier(F.relu(x))           # assumes MyEnsemble has a classifier head
    return x
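Since this forward references a self.classifier that the earlier MyEnsemble does not define, here is a self-contained sketch of that variant. The classifier head and its sizes are my assumptions, chosen to match the 2-dim outputs of modelA and modelB above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB, num_classes=2):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        # Hypothetical head: modelA and modelB each emit 2 features
        self.classifier = nn.Linear(2 + 2, num_classes)

    def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # modelB stays frozen
        x = self.classifier(F.relu(x))
        return x

# Usage: each sub-model gets its own input
model = MyEnsemble(MyModelA(), MyModelB())
out = model(torch.randn(1, 10), torch.randn(1, 2))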

