
GPT-2: Continue training from a checkpoint

I am trying to continue training from a saved checkpoint using the Colab notebook for gpt-2-simple at:

https://colab.research.google.com/drive/1SvQne5O_7hSdmPvUXl5UzPeG5A6csvRA#scrollTo=aeXshJM-Cuaf

But I just can't get it to work. Loading the saved checkpoint from my Google Drive works fine, and I can use it to generate text, but I can't continue training from that checkpoint. In gpt2.finetune() I am passing restore_from='latest' and overwrite=True. I have tried both the same run_name and a different one, with and without overwrite=True. I have also tried restarting the runtime in between, as was suggested, but it doesn't help. I keep getting the following error:

"ValueError: Variable model/wpe already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?"

I assume that I need to run gpt2.load_gpt2(sess, run_name='myRun') before continuing training, but whenever I run it first, gpt2.finetune() throws this error.

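You don't need to (and can't) run load_gpt2() before finetuning: load_gpt2() already builds the GPT-2 graph in the current session, so when finetune() tries to create the same variables again you get the "Variable model/wpe already exists" error. Instead, simply pass run_name to finetune(). I agree that this is confusing as hell; I had the same trouble.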

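import gpt_2_simple as gpt2

# Example values; replace with your own
file_name = "input.txt"        # training text file
model_name = "124M"            # base model from gpt2.download_gpt2()
checkpoint_dir = "checkpoint"
run_name = "run1"              # the run you want to resume

sess = gpt2.start_tf_sess()
# No load_gpt2() needed: finetune() restores the latest checkpoint itself
gpt2.finetune(sess,
    file_name,
    model_name=model_name,
    checkpoint_dir=checkpoint_dir,
    run_name=run_name,
    steps=25,
)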

This automatically picks up the latest checkpoint from your checkpoint/<run_name> folder, loads its weights, and continues training where it left off. You can confirm this from the step number: it doesn't restart from the beginning. E.g., if you had previously trained for 25 steps, it will start at 26:

Training...

[26 | 7.48] loss=0.49 avg=0.49
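The resumed step number comes from a small bookkeeping file rather than from the weights themselves. As far as I can tell from gpt-2-simple's source, finetune() writes the last completed step to a plain-text file named counter inside checkpoint/<run_name>, and with restore_from='latest' it resumes at that value plus one. The sketch below assumes that layout; next_step is a hypothetical helper, not part of gpt-2-simple, and only reads the file so you can check what a resumed run should print before starting it:

```python
import os
import tempfile


def next_step(checkpoint_dir: str, run_name: str) -> int:
    """Return the step number a resumed finetune() run should print first.

    Assumption: gpt-2-simple keeps the last completed step as plain text in
    <checkpoint_dir>/<run_name>/counter. With no counter file the run is
    treated as new and counting starts at 1.
    """
    counter_path = os.path.join(checkpoint_dir, run_name, "counter")
    if not os.path.isfile(counter_path):
        return 1
    with open(counter_path) as fp:
        return int(fp.read().strip()) + 1


if __name__ == "__main__":
    # Simulate a run that was saved after 25 steps
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "run1"))
        with open(os.path.join(root, "run1", "counter"), "w") as fp:
            fp.write("25\n")
        print(next_step(root, "run1"))       # 26
        print(next_step(root, "other_run"))  # 1
```

Under the same assumption, if this returns 1 for a run you have already trained, the counter file is missing from that folder (for example, it was not copied back from Google Drive), so the printed step count restarts even if the weights themselves are restored.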

Also note that to run finetuning multiple times (or to load another model) you normally have to restart the Python runtime. Instead, you can run this before each finetune command (on TensorFlow 2.x the call is tf.compat.v1.reset_default_graph()):

import tensorflow as tf
tf.reset_default_graph()

I've tried the following and it works fine:

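import tensorflow as tf
import gpt_2_simple as gpt2

# Example values; replace with your own
file_name = "input.txt"
n, z, x, y = 1000, 10, 200, 500  # steps, print_every, sample_every, save_every

# Clear the graph left by any earlier finetune/load call
# (on TensorFlow 2.x, use tf.compat.v1.reset_default_graph())
tf.reset_default_graph()
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
          steps=n,
          dataset=file_name,
          model_name='model',       # base model folder under models/, e.g. '124M'
          print_every=z,
          run_name='run_name',      # same run_name as the run to resume
          restore_from='latest',
          sample_every=x,
          save_every=y
          )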

You must pass the same run_name as the model whose training you want to resume, and set restore_from='latest'.
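A mismatched run_name does not raise an error. In the gpt-2-simple versions I have looked at, restore_from='latest' with no checkpoint under checkpoint/<run_name> quietly falls back to the pretrained base model, so training starts over. A quick way to catch a typo is to look for the plain-text checkpoint index file TensorFlow writes next to the saved weights (the same file tf.train.latest_checkpoint reads). The helper below is a hypothetical, stdlib-only sketch that assumes the usual index format with a line like model_checkpoint_path: "model-25":

```python
import os
import re
import tempfile
from typing import Optional


def latest_checkpoint_name(checkpoint_dir: str, run_name: str) -> Optional[str]:
    """Return the checkpoint that restore_from='latest' would resume, or None.

    Assumption: TensorFlow's text index file <checkpoint_dir>/<run_name>/checkpoint
    contains a line such as: model_checkpoint_path: "model-25"
    """
    index_path = os.path.join(checkpoint_dir, run_name, "checkpoint")
    if not os.path.isfile(index_path):
        return None
    with open(index_path) as fp:
        # ^ with MULTILINE skips the "all_model_checkpoint_paths" lines
        match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', fp.read(), re.MULTILINE)
    return match.group(1) if match else None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "run1"))
        with open(os.path.join(root, "run1", "checkpoint"), "w") as fp:
            fp.write('model_checkpoint_path: "model-25"\n'
                     'all_model_checkpoint_paths: "model-25"\n')
        print(latest_checkpoint_name(root, "run1"))   # model-25
        print(latest_checkpoint_name(root, "run_1"))  # None (typo in run_name)
```

Running this check before finetune() is cheap and avoids discovering, many steps later, that training silently restarted from the base model.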
