简体   繁体   中英

tensorflow : restore from checkpoint for continue training

in this case ,i want to continue train my model from checkpoint.i use the cifar-10 example and did a little change in cifar-10_train.py like below,they are almost the same,except i want to restore from checkpoint: i replaced cifar-10 by md.

"""

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import time
import numpy

import tensorflow.python.platform
from tensorflow.python.platform import gfile

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import md
"""



"""


FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/root/test/INT/tbc',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 60000,        # 55000 steps per epoch
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '/root/test/INT/',
                           """If specified, restore this pretrained model """
                           """before beginning any training.""")





def error_rate(predictions, labels):
  """Return the error rate based on dense predictions and 1-hot labels."""
  return 100.0 - (
      100.0 *
      numpy.sum(numpy.argmax(predictions, 0) == numpy.argmax(labels, 0)) /
      predictions.shape[0])






def train():
  """Train MD65500 for a number of steps."""
  with tf.Graph().as_default():
    # global_step = tf.Variable(0, trainable=False)

    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)






    # Get images and labels for CIFAR-10.
    images, labels = md.distorted_inputs()


    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = md.inference(images)

    # Calculate loss.
    loss = md.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = md.train(loss, global_step)

    # Predictions for the minibatch. there is no validation set or test set.
    # train_prediction = tf.nn.softmax(logits)
    train_prediction = logits

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    # sess = tf.Session(config=tf.ConfigProto(
        # log_device_placement=FLAGS.log_device_placement))
    # sess.run(init)

    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    # sess.run(init)


    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      # variables_to_restore = tf.get_collection(
          # slim.variables.VARIABLES_TO_RESTORE)


      variable_averages = tf.train.ExponentialMovingAverage(
          md.MOVING_AVERAGE_DECAY)
      variables_to_restore = {}
      for v in tf.all_variables():
        if v in tf.trainable_variables():
            restore_name = variable_averages.average_name(v)
        else:
            restore_name = v.op.name
        variables_to_restore[restore_name] = v


      ckpt = tf.train.get_checkpoint_state(FLAGS.pretrained_model_checkpoint_path)
      if ckpt and ckpt.model_checkpoint_path:

        # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, ckpt.model_checkpoint_path)
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), ckpt.model_checkpoint_path))
        # print("variables_to_restore")
        # print(variables_to_restore)

      else:
        sess.run(init)










    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph)       #####graph_def=sess.graph_def)

    # tf.add_to_collection('train_op', train_op)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value, predictions = sess.run([train_op, loss, train_prediction])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 100 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))
        # print('Minibatch error: %.5f%%' % error_rate(predictions, labels))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):  # pylint: disable=unused-argument
  # md.maybe_download()
  # if gfile.Exists(FLAGS.train_dir):
    # gfile.DeleteRecursively(FLAGS.train_dir)
  # gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()

when i run the code,errors like this:

[root@bogon md try]# pythonnew mdtbc_3.py
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcublas.so locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcudnn.so locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcufft.so locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:105] successfully opened CUDA library libcurand.so locally
Filling queue with 4000 CIFAR images before starting to train. This will take a few minutes.
I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:900] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 0 with properties: 
name: GeForce GTX 980 Ti
major: 5 minor: 2 memoryClockRate (GHz) 1.228
pciBusID 0000:01:00.0
Total memory: 6.00GiB
Free memory: 5.78GiB
I tensorflow/core/common_runtime/gpu/gpu_init.cc:126] DMA: 0 
I tensorflow/core/common_runtime/gpu/gpu_init.cc:136] 0:     Y 
I tensorflow/core/common_runtime/gpu/gpu_device.cc:755] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 980 Ti, pci bus id: 0000:01:00.0)
2016-08-30 17:12:48.883303: Pre-trained model restored from /root/test/INT/model.ckpt-59999
WARNING:tensorflow:When passing a `Graph` object, please use the `graph` named argument instead of `graph_def`.
Traceback (most recent call last):
    File "mdtbc_3.py", line 195, in <module>
        tf.app.run()
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
        sys.exit(main(sys.argv))
    File "mdtbc_3.py", line 191, in main
        train()
    File "mdtbc_3.py", line 160, in train
        _, loss_value, predictions = sess.run([train_op, loss, train_prediction])
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 340, in run
        run_metadata_ptr)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 564, in _run
        feed_dict_string, options, run_metadata)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 637, in _do_run
        target_list, options, run_metadata)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 659, in _do_call
        e.code)
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value conv2/weights
     [[Node: conv2/weights/read = Identity[T=DT_FLOAT, _class=["loc:@conv2/weights"], _device="/job:localhost/replica:0/task:0/cpu:0"](conv2/weights)]]
Caused by op u'conv2/weights/read', defined at:
    File "mdtbc_3.py", line 195, in <module>
        tf.app.run()
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
        sys.exit(main(sys.argv))
    File "mdtbc_3.py", line 191, in main
        train()
    File "mdtbc_3.py", line 77, in train
        logits = md.inference(images)
    File "/root/test/md try/md.py", line 272, in inference
        stddev=0.1, wd=0.0)
    File "/root/test/md try/md.py", line 114, in _variable_with_weight_decay
        tf.truncated_normal_initializer(stddev=stddev))
    File "/root/test/md try/md.py", line 93, in _variable_on_cpu
        var = tf.get_variable(name, shape, initializer=initializer)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 339, in get_variable
        collections=collections)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 262, in get_variable
        collections=collections, caching_device=caching_device)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 158, in get_variable
        dtype=variable_dtype)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 209, in __init__
        dtype=dtype)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 318, in _init_from_args
        self._snapshot = array_ops.identity(self._variable, name="read")
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 609, in identity
        return _op_def_lib.apply_op("Identity", input=input, name=name)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
        op_def=op_def)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2154, in create_op
        original_op=self._default_original_op, op_def=op_def)
    File "/usr/local/pythonnew/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1154, in __init__
        self._traceback = _extract_stack()

when i uncomment the line 107 "sess.run(init)" ,it runs perfectly,but a initialised model,just a new one from sctrach. i want to restore variables from checkpoint , and continue my training.i want to restore.

Without having the rest of your code handy, I'd say that the following part is problematic:

for v in tf.all_variables():
    if v in tf.trainable_variables():
        restore_name = variable_averages.average_name(v)
    else:
        restore_name = v.op.name
    variables_to_restore[restore_name] = v

Because you specify a list of variables you want to restore here, but you exclude some (ie the v.op.name for the ones in the trainable vars). That will change the name of the variable in the net that throws the error (again, without the rest of the code, I cannot really say), st one (or more) vars are not restored properly. Two approaches (which are not very sophisticated) will help you here:

  1. If you do not store all variables, do an initialization first, and then restore the variables you have actually stored. This makes sure that tensors you do not really care about get initialized none the less
  2. TF is very efficient when it comes to storing nets. If in doubt, store all variables ...

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM