
How does PyTorch know to which neural network the training loss shall be propagated back if you have multiple neural networks?

I want to train a neural network with the help of two other neural networks, which are already trained and tested. The input of the network that I want to train is simultaneously fed into the first static network. The output of the network that I want to train is fed into the second static network. The loss shall be computed on the outputs of the static networks and propagated back to the network being trained.

# Initialization
var_model_statemapper = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 8)])

var_model_panda = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 27)])
var_model_panda.load_state_dict(torch.load("panda.pth"))

var_model_ur5 = NeuralNetwork(8, [('linear', 8), ('relu', None), ('dropout', 0.2), ('linear', 24)])
var_model_ur5.load_state_dict(torch.load("ur5.pth"))

var_loss_function = torch.nn.MSELoss()
var_optimizer = torch.optim.Adam(var_model_statemapper.parameters(), lr=0.001)

# Forward Propagation
var_statemapper_output = var_model_statemapper(var_statemapper_input)  # network being trained
var_panda_output = var_model_panda(var_statemapper_input)              # first static network
var_ur5_output = var_model_ur5(var_statemapper_output)                 # second static network
var_train_loss = var_loss_function(var_panda_output, var_ur5_output)

# Backward Propagation
var_optimizer.zero_grad()
var_train_loss.backward()
var_optimizer.step()

You can see that "var_model_statemapper" is the network that shall be trained. The networks "var_model_panda" and "var_model_ur5" are initialized and their state_dicts are read from the corresponding ".pth" files, so these networks need to stay static. My main question is: which of the networks is updated during backward propagation? Just "var_model_statemapper", or all of them? And if it isn't only "var_model_statemapper" that gets updated, how do I achieve that? Does PyTorch know which network to update just from the initialization of the optimizer?

Formalizing your pipeline to get a good idea of the setup:

x --- | state_mapper | --> y --- | ur5 | --> ur5_out
 \                                              |
  \                                             ↓
   \--- | panda | --> panda_out ----------- | loss_fn | --> loss

Here is what is happening with the lines you provided:

var_optimizer.zero_grad()  # 1.
var_train_loss.backward()  # 2.
var_optimizer.step()       # 3.

  1. Calling zero_grad on an optimizer clears the cached gradients of all parameters registered in that optimizer. In your case, var_optimizer is registered with the parameters of var_model_statemapper (the model that you want to optimize).

  2. When you compute the loss and backpropagate on it via the backward call, gradients are computed for the parameters of all three models.

  3. Calling step on the optimizer then updates only the parameters registered in the optimizer it is called on. In your case, this means var_optimizer.step() will update the parameters of var_model_statemapper alone, using the gradients computed in step 2 (namely via the backward call on var_train_loss). A short sketch illustrating this is given right after this list.
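Here is a minimal sketch of that behaviour, using two hypothetical stand-in models (the names, sizes, and random data below are assumptions, not your actual networks): only the parameters passed to the optimizer change after step, even though gradients are computed for the static model as well.

import torch

# Two tiny stand-in models (assumed shapes, just for illustration)
model_trainable = torch.nn.Linear(4, 2)   # plays the role of var_model_statemapper
model_static = torch.nn.Linear(2, 2)      # plays the role of a static network

optimizer = torch.optim.Adam(model_trainable.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# Keep copies of the weights before the update
before_trainable = model_trainable.weight.detach().clone()
before_static = model_static.weight.detach().clone()

x = torch.randn(8, 4)
target = torch.randn(8, 2)
loss = loss_fn(model_static(model_trainable(x)), target)

optimizer.zero_grad()   # 1. clears gradients of the registered parameters
loss.backward()         # 2. computes gradients for both models' parameters
optimizer.step()        # 3. updates only the registered parameters

print(torch.equal(before_trainable, model_trainable.weight))  # False: updated
print(torch.equal(before_static, model_static.weight))        # True: unchanged
print(model_static.weight.grad is not None)                   # True: gradient was still computed

The optimizer only ever holds references to the parameters it was constructed with, which is how PyTorch knows which network to update.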

All in all, your current approach will only update the parameters of var_model_statemapper. Ideally, you should also freeze var_model_panda and var_model_ur5 by setting their parameters' requires_grad flag to False. This speeds up inference and training, since their gradients won't be computed and stored during backpropagation.
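A minimal sketch of the freezing step, reusing the variable names from your snippet (assuming the models are already constructed and loaded as above):

# Freeze the two pretrained networks: no gradients will be computed or stored for them
for param in var_model_panda.parameters():
    param.requires_grad_(False)

for param in var_model_ur5.parameters():
    param.requires_grad_(False)

# The optimizer still only receives the trainable network's parameters
var_optimizer = torch.optim.Adam(var_model_statemapper.parameters(), lr=0.001)

Even without freezing, only var_model_statemapper would be updated, since var_optimizer only knows about its parameters; freezing simply avoids computing and storing gradients for the two static networks.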
