簡體   English   中英

如何連接 2 個 pytorch 模型並使第一個模型在 PyTorch 中不可訓練

[英]How to concatenate 2 pytorch models and make the first one non-trainable in PyTorch

我有兩個網絡,我需要為完整的 model 連接它們。 然而,我的第一個 model 是經過預訓練的,我需要在訓練完整的 model 時使其不可訓練。 如何在 PyTorch 中實現這一點。

我可以使用這個答案連接兩個模型

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2

# Create models and load state_dicts    
modelA = MyModelA()
modelB = MyModelB()
# Load state dicts
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)

基本上在這里,我想加載預訓練的模型 A 並在訓練 Ensemble modelA時使其不可訓練。

您可以通過將requires_grad設置為 false 來凍結您不想訓練的 model 的所有參數。 像這樣:

for param in model.parameters():
    param.requires_grad = False

這應該適合你。

另一種方法是在你的火車循環中處理這個:

modelA = MyModelA()
modelB = MyModelB()

criterionB = nn.MSELoss()
optimizerB = torch.optim.Adam(modelB.parameters(), lr=0.001)

for epoch in range(epochs):
    for samples, targets in dataloader:
        optimizerB.zero_grad()

        x = modelA.train()(samples)
        predictions = modelB.train()(samples)
    
        loss = criterionB(predictions, targets)
        loss.backward()
        optimizerB.step()

因此,您將模型A 的 output 傳遞給模型B,但您只優化模型B。

一種簡單的方法是detach您不想更新的 model 的 output 張量,它不會將梯度反向傳播到連接的 model。 在您的情況下,您可以在與MyEnsemble model 的前向modelB中的x1連接之前簡單地detach x2張量,以保持模型 B 的權重不變。

因此,新的前向 function 應該如下所示:

def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2.detach()), dim=1)  # Detaching x2, so modelB wont be updated
        x = self.classifier(F.relu(x))
        return x

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM