
Continue training of a custom tf.Estimator with AdamOptimizer

I created a custom tf.Estimator whose weights I train with tf.train.AdamOptimizer. When I continue training an existing model, TensorBoard shows a steep change in the metrics at the start of the continued run; after a few steps, they stabilise again. This looks similar to the initial transient when training a model from scratch. The behaviour is the same whether I continue training on the same Estimator instance or recreate the estimator from a checkpoint. I suspect that Adam's moving averages and/or its bias-correction factors are reset when training restarts. The model weights themselves seem to be restored properly, since the metrics continue from roughly where they had settled before; only the effective learning rate seems to be too high.
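The suspicion above can be made concrete with a single scalar Adam update. This is a pure-Python sketch, not TensorFlow code: it uses the textbook form of the update with the default hyperparameters of tf.train.AdamOptimizer (learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8; TF's kernel applies epsilon slightly differently), and the sign-flipping gradient sequence is a hypothetical stand-in for noisy minibatch gradients. It compares the next update when the optimizer state is kept with the update when the weight is kept but m, v and the step count are reset.

```python
import math


def adam_step(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update; `state` holds the moments m, v and the step count t."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])  # bias-corrected second moment
    update = lr * m_hat / (math.sqrt(v_hat) + eps)
    return param - update, update


def fresh_state():
    return {"m": 0.0, "v": 0.0, "t": 0}


# Hypothetical noisy gradients: the sign flips every step, magnitude 1.
grads = [1.0 if i % 2 == 0 else -1.0 for i in range(1000)]

state = fresh_state()
w = 0.0
for g in grads:
    w, _ = adam_step(w, g, state)

# Continue with the restored optimizer state (copied so `state` is untouched) ...
_, update_restored = adam_step(w, 1.0, dict(state))
# ... versus keeping the weight but resetting m, v and t.
_, update_reset = adam_step(w, 1.0, fresh_state())

print(f"restored: {abs(update_restored):.2e}, reset: {abs(update_reset):.2e}")
```

With a reset state, m_hat / sqrt(v_hat) equals g / |g| on the first step, so the weight moves by almost the full learning rate regardless of how noisy the gradients have been. That is the "effective learning rate is too high" pattern the question describes.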

Previous Stack Overflow answers seem to suggest that these auxiliary optimizer variables should be stored in the checkpoints together with the model weights. So what am I doing wrong here? How can I control whether these auxiliary variables are restored? I would like to be able to continue training as if it had never been stopped. Other people, however, sometimes seem to look for the opposite: resetting the optimizer completely without resetting the model weights. An answer that shows how to achieve both would probably be most helpful.

Here is a sketch of my model_fn:

import tensorflow as tf  # TensorFlow 1.x API (tf.estimator, tf.train, tf.losses)


def model_fn(features, labels, mode, params):
    inputs = features['inputs']
    # create_model is my own network-building function (not shown).
    logits = create_model(inputs, training=mode == tf.estimator.ModeKeys.TRAIN)

    if mode == tf.estimator.ModeKeys.PREDICT:
        ...

    if mode == tf.estimator.ModeKeys.TRAIN:
        outputs = labels['outputs']

        loss = tf.losses.softmax_cross_entropy(
            tf.one_hot(outputs, tf.shape(inputs)[-1]),
            logits,
#            reduction=tf.losses.Reduction.MEAN,
        )
        # params is passed to the Estimator as a dict, so index it by key.
        optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])

        # Run pending update ops (e.g. batch-norm moving averages) with each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss, global_step=tf.train.get_or_create_global_step())

        # Streaming metric: accuracy[1] updates and returns the running accuracy.
        accuracy = tf.metrics.accuracy(
            labels=outputs,
            predictions=tf.argmax(logits, axis=-1),
        )

        tf.summary.histogram('logits', logits)
        tf.summary.scalar('accuracy', accuracy[1])
        tf.summary.scalar('loss', loss)

        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=train_op)

    if mode == tf.estimator.ModeKeys.EVAL:
        ...

    raise ValueError(mode)

The training step is called as follows:

cfg = tf.estimator.RunConfig(
    save_checkpoints_secs = 5*60,  # Save checkpoints every 5 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
    save_summary_steps = 10,
    log_step_count_steps = 100,
)
estimator = tf.estimator.Estimator(
    model_fn = model_fn,
    params = dict(
        learning_rate = 1e-3,
    ),
    model_dir = model_dir,
    config=cfg,
)
# train for the first time
estimator.train(
    input_fn=train_input_fn,
)
# ... at some later time, train again
estimator.train(
    input_fn=train_input_fn,
)

EDIT:

The documentation of the warm_start_from argument of tf.estimator.Estimator and of tf.estimator.WarmStartSettings is not entirely clear about what exactly happens in the default case, which is what I use in the example above. However, the documentation of tf.train.warm_start (https://www.tensorflow.org/api_docs/python/tf/train/warm_start) seems to suggest that in the default case, all TRAINABLE_VARIABLES are warm-started, which

excludes variables such as accumulators and moving statistics from batch norm

Indeed, I find Adam's accumulator variables in VARIABLES, but not in TRAINABLE_VARIABLES. These documentation pages also explain how to change the set of warm-started variables: pass either a list of tf.Variable instances or a list of their names. However, one question remains: how do I create one of those lists in advance, given that with tf.Estimator I have no graph to collect those variables (or their names) from?

EDIT2:

The source code of warm_start reveals an undocumented feature: the list of variable names is in fact a list of regexes, which are matched against GLOBAL_VARIABLES. Thus, one may use

    warm_start_from=tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=str(model_dir),
    #    vars_to_warm_start=".*", # everything in TRAINABLE_VARIABLES - excluding optimiser params 
        vars_to_warm_start=[".*"], # everything in GLOBAL_VARIABLES - including optimiser params 
    ),

to load all variables. However, even with that, the spikes in the summary statistics remain, and I'm now completely at a loss as to what is going on.
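The selection rule described in EDIT2 can be checked without TensorFlow. The pure-Python sketch below mimics it under stated assumptions: a single string is matched only against the trainable variables, a list of strings against all global variables, each using re.match (anchored at the start of the name) the way tf.get_collection filters by scope. The variable names are hypothetical but follow the naming TF1's AdamOptimizer typically uses (slots "<var>/Adam" and "<var>/Adam_1", plus "beta1_power" and "beta2_power"). The last pattern is one assumed way to get the opposite behaviour asked about in the question: keep the weights while leaving the optimizer state freshly initialized.

```python
import re

# Hypothetical variable names for one dense layer trained with Adam (illustrative only).
TRAINABLE = ["dense/kernel:0", "dense/bias:0"]
GLOBAL = TRAINABLE + [
    "dense/kernel/Adam:0", "dense/kernel/Adam_1:0",
    "dense/bias/Adam:0", "dense/bias/Adam_1:0",
    "beta1_power:0", "beta2_power:0",
]


def select(patterns):
    """Mimic vars_to_warm_start: a single string is matched against the
    trainable variables, a list of strings against all global variables."""
    if isinstance(patterns, str):
        candidates, patterns = TRAINABLE, [patterns]
    else:
        candidates = GLOBAL
    # re.match anchors at the start of the name, like a scope filter.
    return sorted({n for p in patterns for n in candidates if re.match(p, n)})


weights_only = select(".*")        # weights only: Adam state starts fresh
everything = select([".*"])        # weights plus Adam slots and beta powers
# Same effect as the single string: a list whose regex excludes Adam-related names.
no_optimizer = select(["^(?!.*(Adam|beta[12]_power)).*"])

print(weights_only)
print(everything)
print(no_optimizer)
```

The negative-lookahead pattern is only useful when the list form is needed for other reasons (for example, to also pick up non-trainable batch-norm statistics) while still excluding the optimizer's variables.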

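By default, metric variables (for tf.metrics.accuracy, a running total and count) are added to the LOCAL_VARIABLES and METRIC_VARIABLES collections, and these are not checkpointed. They are re-initialized every time training starts, so the streaming value accuracy[1] that you write to the summary restarts from scratch, while the weights and Adam's state are restored from the checkpoint as usual.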

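If you want to include them in checkpoints, you can either add each metric variable to the global variables collection (inside model_fn, after the metrics have been created, so that the default Saver picks them up):

for metric_var in tf.get_collection(tf.GraphKeys.METRIC_VARIABLES):
    tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, metric_var)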

Or you can return a Scaffold with a custom Saver, passing the variables to checkpoint to the Saver's var_list argument. That argument defaults to the global variables collection.
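To see how metric variables that are not checkpointed can by themselves produce a jump like the one in the question, here is a pure-Python sketch (no TensorFlow) of a streaming accuracy kept as a running total and count, in the spirit of tf.metrics.accuracy, assuming equal batch sizes. The per-batch accuracies are hypothetical: the model improves from 0.5 to 0.9 during the first run and stays at 0.9 once training is continued.

```python
class StreamingAccuracy:
    """Running mean kept in two counters, similar in spirit to the total/count
    local variables behind tf.metrics.accuracy (equal batch sizes assumed)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_accuracy):
        self.total += batch_accuracy
        self.count += 1
        return self.total / self.count  # the value written to the summary


# Hypothetical per-batch accuracies for the two training runs.
first_batches = [0.5 + 0.4 * i / 999 for i in range(1000)]
second_batches = [0.9] * 1000

metric = StreamingAccuracy()
first_run = [metric.update(a) for a in first_batches]

# Second estimator.train() call: the weights come back from the checkpoint,
# but the metric counters are local variables and start again from zero.
metric = StreamingAccuracy()
second_run = [metric.update(a) for a in second_batches]

print(f"end of first run: {first_run[-1]:.3f}, start of second run: {second_run[0]:.3f}")
```

Because the first run's value still averages in the early, poor batches, it lags behind the model's current accuracy. Once the counters restart, the summary jumps to the current batch accuracy, even though nothing about the optimizer changed.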
