
tf.train.Saver - Load latest checkpoint on different machine

I have a trained model that was saved using tf.train.Saver, producing 4 relevant files:

  • checkpoint
  • model_iter-315000.data-00000-of-00001
  • model_iter-315000.index
  • model_iter-315000.meta

Since the model was trained inside a Docker container, the paths inside the container differ from those on the machine itself, effectively as if we were working on two different machines.

I am trying to load the saved model, outside of the container.

When I run the following:

sess = tf.Session()
saver = tf.train.import_meta_graph('path_to_.meta_file_on_new_machine')  # Works
saver.restore(sess, tf.train.latest_checkpoint('path_to_ckpt_dir_on_new_machine'))  # Fails

the error is:

tensorflow.python.framework.errors_impl.NotFoundError: PATH_ON_OLD_MACHINE ; No such file or directory

Even though I supply the new path when calling tf.train.latest_checkpoint, the error still displays the path from the old machine.

How can I solve this?

The "checkpoint" file is an index file which itself has paths embedded in it. Open it in a text editor and change the paths to the correct new ones.

Alternatively, use tf.train.load_checkpoint() to load a specific checkpoint and not rely on TensorFlow finding the latest one for you. In this case it won't refer to the "checkpoint" file and the different paths will not be a problem.
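For example, restoring from an explicit checkpoint prefix bypasses the "checkpoint" state file entirely (a sketch; the paths below are placeholders for wherever the files live on the new machine, and it assumes a TF1-style session):

```python
import tensorflow as tf

sess = tf.Session()
# Rebuild the graph from the .meta file copied to the new machine.
saver = tf.train.import_meta_graph('/path/on/new/machine/model_iter-315000.meta')
# Passing the prefix directly means the "checkpoint" state file is never read,
# so the stale paths embedded in it cannot cause a NotFoundError.
saver.restore(sess, '/path/on/new/machine/model_iter-315000')
```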

Or write a small script to modify the contents of "checkpoint".

If you open the checkpoint file, you will see something like this:

model_checkpoint_path: "/PATH/ON/OLD/MACHINE/model.ckpt-315000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-300000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-285000"
[...]

Just remove /PATH/ON/OLD/MACHINE/, or replace it with /PATH/ON/NEW/MACHINE/, and you're good to go.

Edit: In the future, when creating your tf.train.Saver, you should use the save_relative_paths option. Quoting the doc:

save_relative_paths : If True, will write relative paths to the checkpoint state file. This is needed if the user wants to copy the checkpoint directory and reload from the copied directory.
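In code, that looks like this (a sketch assuming a TF1-style training loop; sess, step, and the 'ckpt_dir/model_iter' prefix are placeholders, not part of the question):

```python
import tensorflow as tf

# With save_relative_paths=True, the "checkpoint" state file stores paths
# relative to the checkpoint directory, so the directory can be copied or
# mounted elsewhere and restored without editing anything.
saver = tf.train.Saver(save_relative_paths=True)
saver.save(sess, 'ckpt_dir/model_iter', global_step=step)
```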

Here's an approach that doesn't require editing the checkpoint file or manually looking inside the checkpoint directory. If we know the checkpoint prefix, we can use a regex plus the assumption that TensorFlow writes the latest checkpoint on the first line of the checkpoint file:

import tensorflow as tf
import os
import re


def latest_checkpoint(ckpt_dir, ckpt_prefix="model.ckpt", return_relative=True):
    if return_relative:
        # The first line of the "checkpoint" file names the latest checkpoint:
        #   model_checkpoint_path: "/some/stale/path/model.ckpt-315000"
        with open(os.path.join(ckpt_dir, "checkpoint")) as f:
            text = f.readline()
        # Extract the "<prefix>-<step>" basename, ignoring the stale directory part.
        pattern = re.compile(re.escape(ckpt_prefix + "-") + r"[0-9]+")
        basename = pattern.findall(text)[0]
        return os.path.join(ckpt_dir, basename)
    else:
        # Fall back to TensorFlow's own lookup, which uses the embedded paths as-is.
        return tf.train.latest_checkpoint(ckpt_dir)
