
Problem loading PyTorch Tacotron2 model with only the .pth file

I've trained a Tacotron2 model on a custom dataset using Mozilla TTS. The trainer outputs a .pth file and a config.json file, but I have difficulty loading the trained model into PyTorch:

import torch
from torchaudio.models.tacotron2 import Tacotron2

tacotron2 = Tacotron2()
tacotron2.load_state_dict(torch.load('models/best_model.pth'))  # raises the RuntimeError below

RuntimeError: Error(s) in loading state_dict for Tacotron2:
    Missing key(s) in state_dict: "embedding.weight", "encoder.convolutions.0.0.weight", "encoder.convolutions.0.0.bias", "encoder.convolutions.0.1.weight", "encoder.convolutions.0.1.bias", "encoder.convolutions.0.1.running_mean", "encoder.convolutions.0.1.running_var", "encoder.convolutions.1.0.weight", "encoder.convolutions.1.0.bias", "encoder.convolutions.1.1.weight", "encoder.convolutions.1.1.bias", "encoder.convolutions.1.1.running_mean", "encoder.convolutions.1.1.running_var", "encoder.convolutions.2.0.weight", "encoder.convolutions.2.0.bias", "encoder.convolutions.2.1.weight", "encoder.convolutions.2.1.bias", "encoder.convolutions.2.1.running_mean", "encoder.convolutions.2.1.running_var", "encoder.lstm.weight_ih_l0", "encoder.lstm.weight_hh_l0", "encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0", "encoder.lstm.weight_ih_l0_reverse", "encoder.lstm.weight_hh_l0_reverse", "encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse", "decoder.prenet.layers.0.weight", "decoder.prenet.layers.1.weight", "decoder.attention_rnn.weight_ih", "decoder.attention_rnn.weight_hh", "decoder.attention_rnn.bias_ih", "decoder.attention_rnn.bias_hh", "decoder.attention_layer.query_layer.weight", "decoder.attention_layer.memory_layer.weight", "decoder.attention_layer.v.weight", "decoder.attention_layer.location_layer.location_conv.weight", "decoder.attention_layer.location_layer.location_dense.weight", "decoder.decoder_rnn.weight_ih", "decoder.decoder_rnn.weight_hh", "decoder.decoder_rnn.bias_ih", "decoder.decoder_rnn.bias_hh", "decoder.linear_projection.weight", "decoder.linear_projection.bias", "decoder.gate_layer.weight", "decoder.gate_layer.bias", "postnet.convolutions.0.0.weight", "postnet.convolutions.0.0.bias", "postnet.convolutions.0.1.weight", "postnet.convolutions.0.1.bias", "postnet.convolutions.0.1.running_mean", "postnet.convolutions.0.1.running_var", "postnet.convolutions.1.0.weight", "postnet.convolutions.1.0.bias", "postnet.convolutions.1.1.weight", "postnet.convolutions.1.1.bias", "postnet.convolutions.1.1.running_mean", "postnet.convolutions.1.1.running_var", "postnet.convolutions.2.0.weight", "postnet.convolutions.2.0.bias", "postnet.convolutions.2.1.weight", "postnet.convolutions.2.1.bias", "postnet.convolutions.2.1.running_mean", "postnet.convolutions.2.1.running_var", "postnet.convolutions.3.0.weight", "postnet.convolutions.3.0.bias", "postnet.convolutions.3.1.weight", "postnet.convolutions.3.1.bias", "postnet.convolutions.3.1.running_mean", "postnet.convolutions.3.1.running_var", "postnet.convolutions.4.0.weight", "postnet.convolutions.4.0.bias", "postnet.convolutions.4.1.weight", "postnet.convolutions.4.1.bias", "postnet.convolutions.4.1.running_mean", "postnet.convolutions.4.1.running_var".
    Unexpected key(s) in state_dict: "config", "model", "optimizer", "scaler", "step", "epoch", "date", "model_loss".

According to the error message, load_state_dict() was expecting a dictionary whose keys are the network's parameter names ("decoder.attention_rnn.bias_hh" and so on), i.e. the trained parameters and a way to identify them. The .pth checkpoint, however, appears to be a pickled Python dictionary containing everything needed to resume training, rather than just the weights needed to employ the model. My guesses for its keys (see the inspection sketch after this list):

  • "config" are the arguments passed to the model during construction
  • "model" the trained weights
  • "optimizer" the optimizer state
  • "scaler" no idea
  • "step" the trainng step,
  • "epoch" the training epoch,
  • "date" self explanatory and
  • "model_loss" ditto.

Try perhaps

checkpoint = torch.load('models/best_model.pth')
tacotron2.load_state_dict(checkpoint["model"])  # the trained weights live under "model"

and see what happens. If that doesn't work, check the keys of the nested dictionary checkpoint["model"] and explore from there.
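
If the names under checkpoint["model"] don't exactly match what torchaudio's Tacotron2 expects (plausible, since the checkpoint comes from a different codebase), you can load whatever does match and list the leftovers. A sketch; the "module." prefix strip is just an example of one common mismatch (DataParallel training adds that prefix) and may not apply to your checkpoint:

state = checkpoint["model"]
# Example fix-up for one common mismatch: strip a leading "module." prefix.
state = {k.replace("module.", "", 1): v for k, v in state.items()}

# strict=False loads every matching key and reports the rest instead of raising.
result = tacotron2.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)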

If you passed any non-default arguments during training, you'll need to replicate them when constructing the model for loading, too (hint: use the checkpoint's "config" entry).
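
As a minimal sketch, assuming Mozilla TTS keeps its audio settings under an "audio" key in config.json (verify against your file, and check help(Tacotron2) for the constructor arguments your torchaudio version actually accepts):

import json

from torchaudio.models.tacotron2 import Tacotron2

with open('models/config.json') as f:
    config = json.load(f)

# "audio"/"num_mels" is an assumption about the config layout; n_mels is the
# matching torchaudio Tacotron2 constructor argument.
n_mels = config.get("audio", {}).get("num_mels", 80)
tacotron2 = Tacotron2(n_mels=n_mels)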
