简体   繁体   English

如何在tf.train优化器中创建Checkpoint存储时刻和其他相关变量

[英]How to make the Checkpoint store moments and other relevant variables in tf.train Optimizers

I encountered a problem when my code stopped for some reason on my machine, so I had to restart my code and continue the training process by loading the latest checkpoint file. 我的代码在我的机器上由于某种原因停止时遇到了问题,所以我不得不重新启动代码并通过加载最新的检查点文件继续训练过程。

I found that the performance is not consistent before and after the checkpoint that I loaded and the performance dropped a lot. 我发现在我加载的检查点之前和之后性能不一致,性能下降了很多。

So, since my code uses tf.train.AdamOptimizer , I guess that the checkpoint doesn't store the moment vectors and the gradients in the previous steps, and when I load the checkpoint the moment vectors are initialized as zeros. 因此,由于我的代码使用tf.train.AdamOptimizer ,我猜测检查点不存储前一步中的矩向量和渐变,当我加载检查点时,矩向量被初始化为零。

Am I correct? 我对么?

Is there any method that can help store relevant vectors for the Adamopotimizer in the checkpoints so that if my machine is down again, restarting from the latest checkpoint will not influence anything? 有没有任何方法可以帮助在检查点中存储Adamopotimizer的相关向量,这样如果我的机器再次关闭,从最新检查点重新启动不会影响任何东西?

Thanks! 谢谢!

Out of curiosity, I checked if it is true and everything seems to be working just fine: all variables are presented in the checkpoint and restored properly. 出于好奇,我检查了它是否属实并且一切似乎都正常工作:所有变量都显示在检查点中并正确恢复。 See for yourself: 你自己看:

import tensorflow as tf
import sys
import numpy as np
from tensorflow.python.tools import inspect_checkpoint as inch


ckpt_path = "./tmp/model.ckpt"
shape = (2, 2)

def _print_all():
  for v in tf.all_variables():
    print('%20s' % v.name, v.eval())

def _model():
    a = tf.placeholder(tf.float32, shape)
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
      x = tf.get_variable('x', shape)

    loss = tf.matmul(a, tf.layers.batch_normalization(x))
    step = tf.train.AdamOptimizer(0.00001).minimize(loss)
    return a, step

def train():
    a, step = _model()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for i in range(10):
        _ = sess.run(step, feed_dict= {a:np.random.rand(*shape)})

      _print_all()
      print(saver.save(sess, ckpt_path))
      _print_all()


def check():
    a, step = _model()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      _print_all()
      saver.restore(sess, ckpt_path)
      _print_all()


def checkpoint_list_vars(chpnt):
  """
  Given path to a checkpoint list all variables available in the checkpoint
  """
  from tensorflow.contrib.framework.python.framework import checkpoint_utils
  var_list = checkpoint_utils.list_variables(chpnt)
#   for v in var_list: print(v, var_val(v[0]))
#   for v in var_list: print(v)
  var_val('')

  return var_list

def var_val(name):
    inch.print_tensors_in_checkpoint_file(ckpt_path, name, True)

if 'restore' in sys.argv:
    check()
elif 'checkpnt' in sys.argv:
    checkpoint_list_vars(ckpt_path)
else:
    train()

Store it as test.py and run 将其存储为test.py并运行

>> python test.py
>> python test.py checkpnt
>> python test.py restore

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM