TensorFlow: restoring model in a MonitoredSession
I have a model that contains multiple variables, including a global step. I've been able to successfully use a MonitoredSession to save checkpoints and summaries every 100 steps. I was expecting the MonitoredSession to automatically restore all my variables when the session is run in multiple passes (based on this documentation), but this does not happen. If I look at the global step after running the training session again, I find that it starts back at zero. Below is a simplified version of my code without the actual model. Let me know if more code is needed to solve this problem.
train_graph = tf.Graph()
with train_graph.as_default():
    # I create some datasets using the Dataset API
    # ...
    global_step = tf.train.create_global_step()

    # Create all the other variables and the model here
    # ...

    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir='checkpoint/',
        save_secs=None,
        save_steps=100,
        saver=tf.train.Saver(),
        checkpoint_basename='model.ckpt',
        scaffold=None)

    summary_hook = tf.train.SummarySaverHook(
        save_steps=100,
        save_secs=None,
        output_dir='summaries/',
        summary_writer=None,
        scaffold=None,
        summary_op=train_step_summary)

    num_steps_hook = tf.train.StopAtStepHook(num_steps=500)  # Just for testing

    with tf.train.MonitoredSession(
            hooks=[saver_hook, summary_hook, num_steps_hook]) as sess:
        while not sess.should_stop():
            step = sess.run(global_step)
            if step % 100 == 0:
                print(step)
            sess.run(optimizer)
When I run this code the first time, I get this output:
0
100
200
300
400
The checkpoint folder at this point has checkpoints for every hundredth step, up to 500. If I run the program again, I would expect the counter to start at 500 and increase up to 900, but instead I just get the same output and all of my checkpoints get overwritten. Any ideas?
Alright, I figured it out. It was actually really simple. First, it's easier to use a MonitoredTrainingSession() instead of a MonitoredSession(). This wrapper session takes 'checkpoint_dir' as an argument. I thought that the saver_hook would take care of restoring, but that's not the case. To fix my problem, I just had to change the line where I define the session, like so:
with tf.train.MonitoredTrainingSession(hooks=[saver_hook, summary_hook], checkpoint_dir='checkpoint/') as sess:
It can also be done with the MonitoredSession directly, but you need to set up a session_creator instead.