I am training a PyTorch model (named A) together with a pre-trained model (named B).
I want to use the two models in series: the input passes through both models in turn, and the final output is compared with the target label. Model B (the second model) is frozen, and only the parameters of model A (the first model) are learned.
See figure 1 (b) of https://arxiv.org/ftp/arxiv/papers/1905/1905.01898.pdf.
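A minimal sketch of that arrangement, using hypothetical toy stand-ins for A and B (the layer types, sizes and random tensors are assumptions, not the actual models). It only checks that the gradient of the loss still flows back through the frozen B into A, while B's own parameters receive no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: A is trainable, B plays the frozen pre-trained model.
class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x, extra):
        # B takes A's output plus a second input, like freeze_model(out1, input2)
        return self.fc(torch.cat([x, extra], dim=1))

train_model = A()
freeze_model = B()
freeze_model.requires_grad_(False)  # freeze every parameter of B

input1 = torch.randn(3, 4)
input2 = torch.randn(3, 4)
target = torch.randn(3, 2)

out = freeze_model(train_model(input1), input2)
loss = nn.MSELoss()(out, target)
loss.backward()

# Gradients reach A through the frozen B; B's own parameters get none.
a_has_grad = all(p.grad is not None for p in train_model.parameters())
b_has_grad = any(p.grad is not None for p in freeze_model.parameters())
print(a_has_grad, b_has_grad)  # True False
```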
And I wrote the code like this:
import torch
import torch.nn as nn

class A(nn.Module):
    # MY CODE
    ...

class B(nn.Module):
    # MY CODE
    ...

train_model = A()
freeze_model = B()
freeze_model.load_state_dict(torch.load("bestmodel.pth"))
# freeze the pre-trained model B
for param in freeze_model.parameters():
    param.requires_grad = False

criterion = nn.MSELoss()
# only model A's parameters are passed to the optimizer
optimizer = torch.optim.Adam(train_model.parameters(), lr=learning_rate)

for epoch in range(100):
    # ... iteration ...
    out1 = train_model(input1)
    out2 = freeze_model(out1, input2)
    loss = criterion(out2, target)
    optimizer.zero_grad()
    loss.backward()
    # ...
    optimizer.step()
However, only the first iteration produces a real value. After optimizer.step(), the model weights and the loss become NaN.
I suspect the optimizer, or some other cause. Is there a solution?
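The thread does not pin down where the NaN comes from. One generic way to narrow it down (a common debugging technique, not something proposed in this thread) is to stop as soon as the loss or a gradient becomes non-finite, and to clip the gradient norm before optimizer.step(). In the sketch below, the `train_step` helper, the toy `nn.Linear`/`nn.Bilinear` stand-ins and the `max_norm=1.0` threshold are all hypothetical choices.

```python
import torch
import torch.nn as nn

def train_step(model_a, model_b, optimizer, criterion, input1, input2, target,
               max_norm=1.0):
    """Hypothetical training step with NaN/Inf checks and gradient clipping."""
    out = model_b(model_a(input1), input2)
    loss = criterion(out, target)
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss before backward: {loss.item()}")

    optimizer.zero_grad()
    loss.backward()

    # Inspect the trainable model's gradients before the optimizer uses them.
    for name, p in model_a.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite gradient in {name}")

    # Rescale gradients so their global L2 norm is at most max_norm.
    torch.nn.utils.clip_grad_norm_(model_a.parameters(), max_norm)
    optimizer.step()
    return loss.item()

# Toy usage with hypothetical stand-ins for A (trainable) and B (frozen).
torch.manual_seed(0)
model_a = nn.Linear(4, 4)
model_b = nn.Bilinear(4, 4, 2)  # takes (out1, input2), like freeze_model
model_b.requires_grad_(False)
opt = torch.optim.Adam(model_a.parameters(), lr=1e-3)
x1, x2, y = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 2)
losses = [train_step(model_a, model_b, opt, nn.MSELoss(), x1, x2, y)
          for _ in range(5)]
print(losses)
```

If the error is raised on the loss in the very first step, the inputs or targets themselves are a likely suspect; if it is raised on a gradient, a smaller learning rate or the clipping above is worth trying.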
I understand your question.
You cannot use two models with a single optimizer in this way:
out1 = train_model(input1)
out2 = freeze_model(out1,input2)
I am answering your question, but I can only write the outline; you will have to make the required changes yourself:
import torch
import torch.nn as nn

class A(nn.Module):
    # MY CODE
    ...

class B(nn.Module):
    # MY CODE
    ...

train_model = A()
freeze_model = B()
freeze_model.load_state_dict(torch.load("bestmodel.pth"))
# freeze the pre-trained model B
for param in freeze_model.parameters():
    param.requires_grad = False

class Serial(nn.Module):
    def __init__(self):
        super(Serial, self).__init__()
        # the trainable model A followed by the frozen pre-trained model B
        self.train_1 = train_model
        self.freeze_1 = freeze_model

    def forward(self, x, input_2):
        x = self.train_1(x)
        x = self.freeze_1(x, input_2)
        return x

serial_model = Serial()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(serial_model.parameters(), lr=learning_rate)

for epoch in range(100):
    # ... iteration ...
    out = serial_model(input1, input2)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()
    # ...
    optimizer.step()
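A self-contained sketch of the same wrapper idea, under stated assumptions: `nn.Linear` stands in for A and `nn.Bilinear` for the frozen B (both hypothetical), and this `Serial` takes the two models as constructor arguments instead of reading globals. Here the optimizer is given only the parameters with requires_grad=True; PyTorch's Adam also skips parameters whose .grad is None, so B's weights would stay fixed either way. After one step, A's weights change and B's do not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Serial(nn.Module):
    """Chains a trainable model and a frozen model in series."""
    def __init__(self, train_model, freeze_model):
        super().__init__()
        self.train_1 = train_model
        self.freeze_1 = freeze_model

    def forward(self, x, input_2):
        x = self.train_1(x)
        return self.freeze_1(x, input_2)

# Hypothetical stand-ins for A (trainable) and B (frozen, pre-trained).
freeze_model = nn.Bilinear(4, 4, 2)
freeze_model.requires_grad_(False)
serial_model = Serial(nn.Linear(4, 4), freeze_model)

# Only hand the optimizer the parameters that are actually trained.
optimizer = torch.optim.Adam(
    (p for p in serial_model.parameters() if p.requires_grad), lr=1e-3)
criterion = nn.MSELoss()

input1, input2, target = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 2)
a_before = serial_model.train_1.weight.detach().clone()
b_before = serial_model.freeze_1.weight.detach().clone()

out = serial_model(input1, input2)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

a_changed = not torch.equal(a_before, serial_model.train_1.weight.detach())
b_changed = not torch.equal(b_before, serial_model.freeze_1.weight.detach())
print(a_changed, b_changed)  # True False
```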