
How to train a PyTorch model together with another, frozen model?

I am training a PyTorch model (A) together with a pre-trained model (B).

I want to use the two models in series: the input passes through both models, and the final output is compared with the target label. Model 2 (B) is frozen, and training learns only the parameters of model 1 (A).

This is the setup shown in figure 1 (b) of https://arxiv.org/ftp/arxiv/papers/1905/1905.01898.pdf.

I wrote the code like this:

import torch
import torch.nn as nn

class A(nn.Module):
    # MY CODE
    ...

class B(nn.Module):
    # MY CODE
    ...

train_model = A()
freeze_model = B()

# Load the pre-trained weights into B and freeze them
freeze_model.load_state_dict(torch.load("bestmodel.pth"))
for param in freeze_model.parameters():
    param.requires_grad = False

criterion = nn.MSELoss()
# Only A's parameters are handed to the optimizer
optimizer = torch.optim.Adam(train_model.parameters(), lr=learning_rate)

for epoch in range(100):
    # ... iteration ...
    out1 = train_model(input1)
    out2 = freeze_model(out1, input2)
    loss = criterion(target, out2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

However, the first iteration produces a finite loss value, but after optimizer.step() the model weights and the loss become NaN.

I suspect the optimizer, or something else, is causing this. Is there a solution?
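A NaN that appears right after the first optimizer step often comes from exploding gradients, a learning rate that is too high, or NaN/Inf values in the inputs or targets. A common way to narrow it down is sketched below, using two standard PyTorch tools: torch.autograd.detect_anomaly, which makes backward() raise an error at the operation that produced a NaN, and torch.nn.utils.clip_grad_norm_, which caps the gradient norm before the step. This is a minimal sketch under stated assumptions: A and B are replaced by hypothetical single linear layers (B is applied to the concatenation of out1 and input2), and the shapes, learning rate, and max_norm value are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: A is trainable, B is frozen and sees [out1, input2].
train_model = nn.Linear(4, 4)
freeze_model = nn.Linear(8, 1)
for p in freeze_model.parameters():
    p.requires_grad = False

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(train_model.parameters(), lr=1e-3)


def train_step(input1, input2, target, max_norm=1.0):
    """One training step with basic NaN checks (thresholds are assumptions)."""
    # Catch NaN/Inf in the data before it reaches the models.
    for t in (input1, input2, target):
        if not torch.isfinite(t).all():
            raise ValueError("input or target contains NaN/Inf")

    optimizer.zero_grad()
    # Anomaly mode (slow, for debugging only) makes backward() raise an
    # error at the operation that produced a NaN gradient.
    with torch.autograd.detect_anomaly():
        out1 = train_model(input1)
        out2 = freeze_model(torch.cat([out1, input2], dim=1))
        loss = criterion(out2, target)  # (prediction, target) order
        if not torch.isfinite(loss):
            raise RuntimeError(f"non-finite loss: {loss.item()}")
        loss.backward()

    # Clip A's gradients; the returned value is the total norm before clipping.
    grad_norm = torch.nn.utils.clip_grad_norm_(train_model.parameters(), max_norm)
    optimizer.step()
    return loss.item(), grad_norm.item()


loss, grad_norm = train_step(torch.randn(16, 4), torch.randn(16, 4), torch.randn(16, 1))
print(loss, grad_norm)
```

Logging grad_norm over the first few steps shows whether the gradients blow up just before the NaN; if they do, lowering the learning rate or keeping the clipping in place is the usual next thing to try.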

If I understand your question correctly: you cannot use two models with a single optimizer by chaining them like this:

out1 = train_model(input1)
out2 = freeze_model(out1, input2)

I can only outline the structure here; you will have to make the changes your models require.

import torch
import torch.nn as nn

class A(nn.Module):
    # MY CODE
    ...

class B(nn.Module):
    # MY CODE
    ...

train_model = A()
freeze_model = B()
freeze_model.load_state_dict(torch.load("bestmodel.pth"))
for param in freeze_model.parameters():
    param.requires_grad = False  # freeze B's weights


class Serial(nn.Module):
    def __init__(self):
        super().__init__()
        # wrap the trainable model A and the frozen model B
        self.train_1 = train_model
        self.freeze_1 = freeze_model

    def forward(self, x, input_2):
        x = self.train_1(x)            # A: trainable
        x = self.freeze_1(x, input_2)  # B: frozen, gradients still flow through it to A
        return x

serial_model = Serial()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(serial_model.parameters(), lr=learning_rate)

for epoch in range(100):
    # ... iteration ...
    out = serial_model(input1, input2)
    loss = criterion(out, target)  # (prediction, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
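To check that this wrapper really trains only A, here is a minimal self-contained sketch. ToyA and ToyB are hypothetical stand-ins for the real A and B; their layer sizes, the concatenation inside ToyB, and the learning rate are assumptions. It also passes the two sub-models to the Serial constructor instead of reading globals, which is a design choice of this sketch. After one optimizer step it compares the weights: B's frozen weights should be unchanged, while A's should have moved, which means gradients flowed back through the frozen B into A.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyA(nn.Module):
    """Hypothetical trainable model (stands in for A)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return torch.tanh(self.fc(x))


class ToyB(nn.Module):
    """Hypothetical frozen model (stands in for B); takes two inputs."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

    def forward(self, x, input_2):
        return self.fc(torch.cat([x, input_2], dim=1))


class Serial(nn.Module):
    """Runs the trainable model, then the frozen model, in series."""
    def __init__(self, train_model, freeze_model):
        super().__init__()
        self.train_1 = train_model
        self.freeze_1 = freeze_model

    def forward(self, x, input_2):
        return self.freeze_1(self.train_1(x), input_2)


train_model, freeze_model = ToyA(), ToyB()
for p in freeze_model.parameters():
    p.requires_grad = False  # freeze B

serial_model = Serial(train_model, freeze_model)
# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in serial_model.parameters() if p.requires_grad), lr=1e-2
)
criterion = nn.MSELoss()

a_before = train_model.fc.weight.clone()
b_before = freeze_model.fc.weight.clone()

input1, input2, target = torch.randn(16, 4), torch.randn(16, 4), torch.randn(16, 1)
loss = criterion(serial_model(input1, input2), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("A changed:", not torch.equal(a_before, train_model.fc.weight))
print("B unchanged:", torch.equal(b_before, freeze_model.fc.weight))
```

Passing serial_model.parameters() unfiltered, as in the code above, behaves the same way, because Adam skips parameters whose .grad is None; filtering on requires_grad only makes the intent explicit.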


 