繁体   English   中英

如何用另一个冻结 model 训练 pytorch model?

[英]How train the pytorch model with another freeze model?

我使用预训练模型(名称 = B)训练 pytorch 模型(名称 = A)。

我想像这样串联使用两个模型:使用输入,以便通过两个模型输出 output 和目标 label。 Model2 冻结并学习 model1 的学习参数。

https://arxiv.org/ftp/arxiv/papers/1905/1905.01898.pdf图 1 (b)。

我写了这样的代码:

Class A(nn.Module):
      # MY CODE

Class B(nn.Module):
      # MY CODE

train_model = A()
freeze_model = B()

freeze_model.load_state_dict(torch.load("bestmodel.pth"))
for param in freeze_model.parameters():
    param.requires_grad = False

criterion=nn.MSELoss()
optimizer = torch.optim.Adam(train_model.parameters(), lr=learning_rate)

for epoch in range(100):
    ... iteration ...
    out1 = train_model(input1)
    out2 = freeze_model(out1,input2)
    loss = criterion(target,out2)
    ... optimizer.zero_grad, loss backward, ...
    optimizer.step()

然而,第一次迭代出来的是一个真实的值,在优化器_Step之后,model权重和损失都变成了nan..

我认为这是出于优化器或其他原因。 有解决办法吗?

我收到了你的问题:
你不能在单个优化器上使用 2 个模型,使用这个

out1 = train_model(input1)
out2 = freeze_model(out1,input2)

我正在回答您的问题,但我可以只写签名,您必须进行所需的更改

Class A(nn.Module):
      # MY CODE

Class B(nn.Module):
      # MY CODE

train_model = A()
freeze_model = B()
freeze_model.load_state_dict(torch.load("bestmodel.pth"))
for param in freeze_model.parameters():
    param.requires_grad = False


class Serial(nn.Module):
        def __init__(self):
            super(VGG, self).__init__()
            
            # get the pretrained VGG19 network
            self.train_1 = train_model
            self.freeze_1 = freeze_model
           

        def forward(self, x,input_2):
            x = self.train_1(x)
            x = self.freeze_1(x,input_2)
            return x

serial_model = Serial()       
criterion=nn.MSELoss()
optimizer = torch.optim.Adam(serial_model.parameters(), lr=learning_rate)

for epoch in range(100):
    ... iteration ...
    out = serial_model(input1,input2)
    loss = criterion(target,out)
    ... optimizer.zero_grad, loss backward, ...
    optimizer.step()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM