
TensorFlow: restoring model in a MonitoredSession

I have a model that contains multiple variables, including a global step. I've been able to successfully use a MonitoredSession to save checkpoints and summaries every 100 steps. I was expecting the MonitoredSession to automatically restore all my variables when the session is run in multiple passes (based on this documentation); however, this does not happen. If I look at the global step after running the training session again, I find that it starts back from zero. This is a simplified version of my code, without the actual model. Let me know if more code is needed to solve this problem.

import tensorflow as tf

train_graph = tf.Graph()
with train_graph.as_default():
  # I create some datasets using the Dataset API
  # ...

  global_step = tf.train.create_global_step()

  # Create all the other variables and the model here, including the
  # 'optimizer' train op and the 'train_step_summary' summary op used below
  # ...

  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir='checkpoint/',
      save_secs=None,
      save_steps=100,
      saver=tf.train.Saver(),
      checkpoint_basename='model.ckpt',
      scaffold=None)
  summary_hook = tf.train.SummarySaverHook(
      save_steps=100,
      save_secs=None,
      output_dir='summaries/',
      summary_writer=None,
      scaffold=None,
      summary_op=train_step_summary)
  num_steps_hook = tf.train.StopAtStepHook(num_steps=500)  # Just for testing

  with tf.train.MonitoredSession(
      hooks=[saver_hook, summary_hook, num_steps_hook]) as sess:
    while not sess.should_stop():
      step = sess.run(global_step)
      if step % 100 == 0:
        print(step)
      sess.run(optimizer)

When I run this code the first time, I get this output:

0
100
200
300
400

The checkpoint folder at this point has a checkpoint for every hundredth step, up to 500. If I run the program again, I would expect the counter to start at 500 and increase up to 900, but instead I just get the same output and all of my checkpoints get overwritten. Any ideas?

Alright, I figured it out, and it was actually really simple. First, it's easier to use a MonitoredTrainingSession() instead of a MonitoredSession(). This wrapper session takes a 'checkpoint_dir' argument. I thought the saver_hook would take care of restoring, but that's not the case. To fix my problem I just had to change the line where I define the session, like so:

with tf.train.MonitoredTrainingSession(hooks=[saver_hook, summary_hook], checkpoint_dir='checkpoint/') as sess:
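For completeness, here is a sketch of the corrected training loop (same graph, hooks, and ops as in the question above; the 'as sess' binding is needed so the loop can call should_stop() and run()). With checkpoint_dir set, the session restores global_step and every other variable from the latest checkpoint before training resumes:

# Sketch of the corrected loop; assumes the graph, hooks, global_step,
# and optimizer defined in the question's code above.
with tf.train.MonitoredTrainingSession(
    hooks=[saver_hook, summary_hook],
    checkpoint_dir='checkpoint/') as sess:
  while not sess.should_stop():
    step = sess.run(global_step)
    if step % 100 == 0:
      print(step)
    sess.run(optimizer)

Note that when checkpoint_dir is set, MonitoredTrainingSession also installs its own CheckpointSaverHook (every 600 seconds by default); pass save_checkpoint_secs=None to rely solely on the explicit saver_hook.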

It can also be done with MonitoredSession directly, but you need to set up a session_creator instead, as sketched below.
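A minimal sketch of that variant, assuming the TF 1.x API: tf.train.ChiefSessionCreator accepts a checkpoint_dir argument, and a MonitoredSession built from it restores from the latest checkpoint in that directory:

# Restores from the latest checkpoint in 'checkpoint/' (if one exists)
session_creator = tf.train.ChiefSessionCreator(checkpoint_dir='checkpoint/')
with tf.train.MonitoredSession(
    session_creator=session_creator,
    hooks=[saver_hook, summary_hook]) as sess:
  while not sess.should_stop():
    sess.run(optimizer)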
