The documentation of tf.train.Saver.recover_last_checkpoints says that a list of checkpoint paths should be passed to it, but how do I get that list? Hard-coding it would be a silly practice. Parsing the protocol buffer file (the file named checkpoint in the model directory)? TensorFlow does not seem to provide a parser for it, so do I have to implement one myself? Is there a good practice for getting the list of checkpoint paths?
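For context, the checkpoint file is a small text-format CheckpointState protocol buffer, and TensorFlow's tf.train.get_checkpoint_state reads it. The sketch below shows roughly what the file contains and what a hand-rolled parser could look like. parse_checkpoint_index is a hypothetical helper, not a TensorFlow API; it assumes one key: value pair per line and does not unescape quoted strings (for example backslashes in Windows paths), so a real protobuf text parser is the safer choice.

```python
def parse_checkpoint_index(text):
    """Parse the text-format 'checkpoint' file into a dict of lists.

    Hypothetical helper: assumes one 'key: value' pair per line, with
    string values wrapped in double quotes and no escape sequences.
    Unquoted values are kept as plain strings.
    """
    fields = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or ":" not in line:
            continue  # skip blank or unexpected lines
        # Split on the first colon only, so paths like C:\... stay intact.
        key, _, value = line.partition(":")
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] == '"':
            value = value[1:-1]  # drop the surrounding quotes
        # Repeated fields (e.g. all_model_checkpoint_paths) accumulate in order.
        fields.setdefault(key.strip(), []).append(value)
    return fields


example = '''model_checkpoint_path: "model.ckpt-3000"
all_model_checkpoint_paths: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-2000"
all_model_checkpoint_paths: "model.ckpt-3000"
'''

state = parse_checkpoint_index(example)
print(state["model_checkpoint_path"][0])    # model.ckpt-3000
print(state["all_model_checkpoint_paths"])  # ['model.ckpt-1000', 'model.ckpt-2000', 'model.ckpt-3000']
```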
I ask because of a problem that has been bothering me lately. A training run that lasts several days may crash for one reason or another, and I then have to resume it from the latest checkpoint. Resuming training itself is easy; I just need the following code:
restorer = tf.train.Saver()
restorer.restore(sess, latest_checkpoint)
I can hard-code latest_checkpoint or, more sensibly, obtain it with tf.train.latest_checkpoint().
However, a problem arises after I resume training: the checkpoint files created before the crash are left behind. A Saver only manages (and deletes, once max_to_keep is exceeded) the checkpoint files it creates during the current run. I would like it to also manage the checkpoint files from the earlier run, so that they are deleted automatically and I don't have to delete them by hand every time. Repeating that chore is tedious.
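Until a Saver knows about the old checkpoints, the cleanup described above has to be done by hand. A minimal sketch of that manual step follows, assuming TF1-style checkpoints where each checkpoint is a path prefix plus files such as .index, .meta and .data-00000-of-00001, and assuming the prefixes are ordered oldest to newest. delete_stale_checkpoints is a hypothetical helper, not a TensorFlow API, and it does not rewrite the checkpoint index file.

```python
import glob
import os


def delete_stale_checkpoints(checkpoint_prefixes, keep):
    """Delete every file of all but the newest `keep` checkpoints.

    Hypothetical helper. Assumes `checkpoint_prefixes` is ordered oldest
    to newest and that each checkpoint consists of files sharing a prefix,
    e.g. model.ckpt-1000.index, model.ckpt-1000.meta,
    model.ckpt-1000.data-00000-of-00001.
    Returns the list of files that were removed.
    """
    prefixes = list(checkpoint_prefixes)
    stale = prefixes[:max(len(prefixes) - keep, 0)]  # oldest ones only
    removed = []
    for prefix in stale:
        # glob.escape protects special characters in the prefix; the literal
        # "." stops "model.ckpt-1" from also matching "model.ckpt-10.*".
        for path in glob.glob(glob.escape(prefix) + ".*"):
            os.remove(path)
            removed.append(path)
    return removed
```

Matching on the prefix plus a literal dot keeps the helper from touching a newer checkpoint whose step number merely starts with the same digits.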
Then I found the recover_last_checkpoints method of tf.train.Saver, which lets a Saver take over management of old checkpoints. But it is not convenient to use, since it requires that list of checkpoint paths. Is there a good solution?
As @isarandi mentioned in a comment, the easiest way is to retrieve all checkpoint paths with tf.train.get_checkpoint_state and read the all_model_checkpoint_paths field of the returned state, which is essentially an undocumented feature. You can then hand the old checkpoints over to your Saver like this:
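states = tf.train.get_checkpoint_state(your_checkpoint_dir)
if states is not None:  # None when the directory has no checkpoint file yet
    checkpoint_paths = list(states.all_model_checkpoint_paths)  # oldest to newest
    saver.recover_last_checkpoints(checkpoint_paths)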
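Why this helps can be shown with a toy model of the Saver's max_to_keep bookkeeping. This is a simplified sketch, not TensorFlow code: ToyCheckpointTracker is a hypothetical class that only tracks checkpoint prefixes and reports which ones would be deleted. It assumes the same oldest-to-newest ordering that all_model_checkpoint_paths uses.

```python
from collections import deque


class ToyCheckpointTracker:
    """Simplified model of a Saver's max_to_keep bookkeeping (not TF code)."""

    def __init__(self, max_to_keep=5):
        self.max_to_keep = max_to_keep
        self.tracked = deque()  # checkpoint prefixes, oldest first

    def recover_last_checkpoints(self, checkpoint_paths):
        # Adopt checkpoints written by an earlier run, oldest first.
        self.tracked = deque(checkpoint_paths)

    def save(self, path):
        """Record a new checkpoint; return the prefixes that get evicted."""
        self.tracked.append(path)
        evicted = []
        while len(self.tracked) > self.max_to_keep:
            evicted.append(self.tracked.popleft())  # drop the oldest
        return evicted


old_run = ["model.ckpt-1000", "model.ckpt-2000", "model.ckpt-3000"]

fresh = ToyCheckpointTracker(max_to_keep=3)
print(fresh.save("model.ckpt-4000"))    # [] : the old run's files are never cleaned up

resumed = ToyCheckpointTracker(max_to_keep=3)
resumed.recover_last_checkpoints(old_run)
print(resumed.save("model.ckpt-4000"))  # ['model.ckpt-1000']
```

Without the recovery call, the fresh tracker only knows about checkpoints it saved itself, which is exactly why the pre-crash files are left behind.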