tf.train.Saver - Load latest checkpoint on different machine

I have a trained model that was saved using tf.train.Saver, which generated 4 relevant files:

  • checkpoint
  • model_iter-315000.data-00000-of-00001
  • model_iter-315000.index
  • model_iter-315000.meta

Because the model was generated inside a Docker container, the paths inside the container differ from those on the host machine, as if we were working on two different machines.

I am trying to load the saved model outside of the container.

When I run the following:

sess = tf.Session()
saver = tf.train.import_meta_graph('path_to_.meta_file_on_new_machine')  # Works
saver.restore(sess, tf.train.latest_checkpoint('path_to_ckpt_dir_on_new_machine'))  # Fails

the error is:

tensorflow.python.framework.errors_impl.NotFoundError: PATH_ON_OLD_MACHINE; No such file or directory

Even though I supply the new path when calling tf.train.latest_checkpoint, I still get this error, and it shows the path from the old machine.

How can I solve this?

The "checkpoint" file is an index file that has paths embedded in it. Open it in a text editor and change the paths to the correct new ones.

Alternatively, use tf.train.load_checkpoint() to load a specific checkpoint instead of relying on TensorFlow to find the latest one for you. In this case it won't refer to the "checkpoint" file, so the different paths will not be a problem.
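
One way to avoid the "checkpoint" file altogether is to work out the checkpoint prefix yourself and pass it straight to saver.restore(). The sketch below is only an illustration: the helper name newest_checkpoint_prefix is hypothetical, and it assumes checkpoints are named <prefix>-<step> with a matching .index file, as in the file list in the question.

```python
import os
import re


def newest_checkpoint_prefix(ckpt_dir):
    """Return the checkpoint prefix with the highest step number in ckpt_dir.

    Scans for files named "<prefix>-<step>.index" and never reads the
    "checkpoint" state file. Returns None if no such file is found.
    """
    index_pattern = re.compile(r"^(?P<prefix>.+-(?P<step>\d+))\.index$")
    best_step, best_prefix = -1, None
    for filename in os.listdir(ckpt_dir):
        match = index_pattern.match(filename)
        if match and int(match.group("step")) > best_step:
            best_step = int(match.group("step"))
            best_prefix = match.group("prefix")
    if best_prefix is None:
        return None
    return os.path.join(ckpt_dir, best_prefix)
```

The returned prefix can then replace the tf.train.latest_checkpoint(...) call, for example saver.restore(sess, newest_checkpoint_prefix('path_to_ckpt_dir_on_new_machine')).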

Or write a small script to modify the contents of "checkpoint".

If you open the checkpoint file, you will see something like this:

model_checkpoint_path: "/PATH/ON/OLD/MACHINE/model.ckpt-315000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-300000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-285000"
[...]

Just remove the /PATH/ON/OLD/MACHINE/ prefix, or replace it with /PATH/ON/NEW/MACHINE/, and you're good to go.
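
If you would rather not edit the file by hand, the same change can be scripted. This is a minimal sketch, assuming every entry in the file has the form key: "path" and that the old paths use the path separator of the machine running the script; the function name relativize_checkpoint_file is hypothetical.

```python
import os
import re


def relativize_checkpoint_file(ckpt_dir):
    """Rewrite each quoted path in ckpt_dir/checkpoint to its bare file name."""
    state_path = os.path.join(ckpt_dir, "checkpoint")
    with open(state_path) as f:
        text = f.read()

    def strip_directory(match):
        # Keep only the final path component, e.g. "model.ckpt-315000".
        return '"' + os.path.basename(match.group(1)) + '"'

    new_text = re.sub(r'"([^"]*)"', strip_directory, text)
    with open(state_path, "w") as f:
        f.write(new_text)
    return new_text
```

Since this overwrites the file in place, it is worth keeping a backup copy of "checkpoint" before running it.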

Edit: In the future, when creating your tf.train.Saver, you should use the save_relative_paths option. Quoting the docs:

save_relative_paths: If True, will write relative paths to the checkpoint state file. This is needed if the user wants to copy the checkpoint directory and reload from the copied directory.
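
As a rough sketch of what that looks like at save time, assuming the TensorFlow 1.x API used in the question (the helper name make_relative_saver is hypothetical, and the graph must already contain variables when it is called):

```python
def make_relative_saver():
    """Build a Saver whose "checkpoint" state file stores relative paths."""
    # Imported lazily so that defining this helper does not itself need TensorFlow.
    import tensorflow as tf  # assumes the TF 1.x API, as in the question

    # save_relative_paths=True records paths relative to the checkpoint
    # directory, so the directory can be copied or mounted elsewhere.
    return tf.train.Saver(save_relative_paths=True)
```

Call it after the model's variables have been created, then save as usual with saver.save(sess, ...).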

Here's an approach that doesn't require editing the checkpoint file or manually looking inside the checkpoint directory. If we know the checkpoint prefix, we can use a regex together with the assumption that TensorFlow writes the latest checkpoint on the first line of the checkpoint file:

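import os
import re

import tensorflow as tf


def latest_checkpoint(ckpt_dir, ckpt_prefix="model.ckpt", return_relative=True):
    """Return the latest checkpoint path, resolved against ckpt_dir."""
    if return_relative:
        # Assumes the first line of "checkpoint" names the latest checkpoint.
        with open(os.path.join(ckpt_dir, "checkpoint")) as f:
            text = f.readline()
        # Match "<prefix>-<step>" and ignore whatever directory precedes it.
        pattern = re.compile(re.escape(ckpt_prefix + "-") + r"[0-9]+")
        basename = pattern.findall(text)[0]
        return os.path.join(ckpt_dir, basename)
    else:
        # Fall back to TensorFlow's own lookup, which trusts the stored paths.
        return tf.train.latest_checkpoint(ckpt_dir)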
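
A variation on the same idea, offered only as a sketch: instead of relying on the first line and a known prefix, look for the model_checkpoint_path key explicitly. The function name latest_checkpoint_from_state is hypothetical, and the code assumes the entry has the form model_checkpoint_path: "path" with the old path using the local path separator.

```python
import os
import re


def latest_checkpoint_from_state(ckpt_dir):
    """Return ckpt_dir joined with the file name stored under model_checkpoint_path."""
    with open(os.path.join(ckpt_dir, "checkpoint")) as f:
        text = f.read()
    # MULTILINE anchors the key to a line start, so the
    # all_model_checkpoint_paths entries are not matched.
    match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', text, flags=re.MULTILINE)
    if match is None:
        return None
    # Keep only the file name; the stored directory may belong to another machine.
    return os.path.join(ckpt_dir, os.path.basename(match.group(1)))
```

This version needs neither TensorFlow nor the checkpoint prefix, at the cost of depending on the layout of the "checkpoint" file.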
