
Save model checkpoint only when model shows improvement in TensorFlow

Do you know if there is a way to choose which model is saved when using an Estimator wrapped in an Experiment? The model is saved every 'save_checkpoints_steps' steps, but that model is not necessarily the best one.

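import tensorflow as tf
from tensorflow.contrib import learn
from tensorflow.contrib.learn.python.learn import learn_runner

def model_fn(features, labels, mode, params):
    # model_predict_, model_loss and model_train_op are the asker's own
    # helpers (not shown); they build the predictions, the loss and the train op.
    predict = model_predict_()
    loss = model_loss()
    train_op = model_train_op(loss, mode)
    predictions = {"predictions": predict}

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
    )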

def experiment_fn(run_config, hparams):
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=hparams
    )

    # train_input_fn and eval_input_fn are defined elsewhere (not shown).
    return learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        eval_metrics=None,
        train_steps=1000,
    )

# run_config and hparams are created elsewhere (not shown).
ex = learn_runner.run(
    experiment_fn=experiment_fn,
    run_config=run_config,
    schedule="train_and_evaluate",
    hparams=hparams
)

The output is as follows:

INFO:tensorflow:Saving checkpoints for 401 into .\\model.ckpt.
INFO:tensorflow:global_step/sec: 0.157117
INFO:tensorflow:step = 401, loss = 2.95048 (636.468 sec)
INFO:tensorflow:Starting evaluation at 2017-09-05-20:06:07
INFO:tensorflow:Restoring parameters from .\\model.ckpt-401
INFO:tensorflow:Evaluation [1/1]
INFO:tensorflow:Finished evaluation at 2017-09-05-20:06:09
INFO:tensorflow:Saving dict for global step 401: global_step = 401, loss = 7.20411
INFO:tensorflow:Validation (step 401): global_step = 401, loss = 7.20411
INFO:tensorflow:training loss = 2.95048, step = 401 (315.393 sec)
INFO:tensorflow:Saving checkpoints for 451 into .\\model.ckpt.
INFO:tensorflow:Starting evaluation at 2017-09-05-20:11:32
INFO:tensorflow:Restoring parameters from .\\model.ckpt-451
INFO:tensorflow:Evaluation [1/1]

As you can see, it always saves the latest model, which is not necessarily the best one.

Checkpoints are saved in case your training process is interrupted. Without checkpoints you would have to restart from scratch, which is a big problem for large models that take weeks to train.
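A rough, framework-agnostic sketch of what "resuming" means: on restart, training continues from the checkpoint with the highest global step in the model directory instead of from step 0. It assumes the default model.ckpt-<step> prefix seen in the log and that each checkpoint has a .index file. TensorFlow itself finds the newest checkpoint by reading the bookkeeping "checkpoint" file (see tf.train.latest_checkpoint) rather than by parsing filenames, and the helper names below are hypothetical.

```python
import re
from typing import Iterable, Optional

# Matches the ".index" file written for each checkpoint, e.g. "model.ckpt-451.index".
# Assumption: the default "model.ckpt" prefix shown in the training log.
_CKPT_INDEX = re.compile(r"^(model\.ckpt-(\d+))\.index$")


def latest_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the checkpoint prefix with the highest global step, or None."""
    best_step, best_prefix = -1, None
    for name in filenames:
        match = _CKPT_INDEX.match(name)
        # Compare steps as integers so that 1000 beats 401.
        if match and int(match.group(2)) > best_step:
            best_step, best_prefix = int(match.group(2)), match.group(1)
    return best_prefix


def resume_step(filenames: Iterable[str]) -> int:
    """Global step to resume from; 0 means there is nothing to restore."""
    prefix = latest_checkpoint(filenames)
    return 0 if prefix is None else int(prefix.rsplit("-", 1)[1])


if __name__ == "__main__":
    files = [
        "checkpoint",
        "model.ckpt-401.index",
        "model.ckpt-401.data-00000-of-00001",
        "model.ckpt-451.index",
        "model.ckpt-451.meta",
    ]
    print(latest_checkpoint(files))  # model.ckpt-451
    print(resume_step(files))  # 451
    print(resume_step([]))  # 0
```

Note that this always picks the newest checkpoint, not the best one, which is exactly the behaviour the question observes.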

Once your training is done and you are satisfied with your model (in your words, "it is the best"), you can save it explicitly using Estimator.export_savedmodel (https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_savedmodel). Call this method on the Estimator that you used to create your Experiment. Note that this method saves a "model for inference", meaning that all the gradient ops are stripped from it and not saved.
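Exporting by hand works once training is over. To automate the choice, one hedged, framework-agnostic sketch is to compare each evaluation loss with the best loss seen so far and export only on improvement. The class name and the export_fn callback are hypothetical; with TensorFlow 1.x the callback could wrap estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn), where both arguments are yours to supply. Later TensorFlow 1.x releases added tf.estimator.BestExporter, which is built around a similar idea.

```python
from typing import Callable, List, Optional


class ExportOnImprovement:
    """Calls export_fn only when the evaluation loss beats every earlier loss.

    Hypothetical helper: export_fn stands in for whatever actually writes the
    model (e.g. a wrapper around Estimator.export_savedmodel) and returns a path.
    """

    def __init__(self, export_fn: Callable[[int], str]) -> None:
        self.export_fn = export_fn
        self.best_loss: Optional[float] = None
        self.exports: List[str] = []

    def after_eval(self, step: int, loss: float) -> bool:
        """Record one evaluation result; return True if the model was exported."""
        if self.best_loss is not None and loss >= self.best_loss:
            return False  # No improvement: keep the previous export untouched.
        self.best_loss = loss
        self.exports.append(self.export_fn(step))
        return True


if __name__ == "__main__":
    exporter = ExportOnImprovement(lambda step: "export/step-%d" % step)
    # 7.20411 is the validation loss from the question's log; the rest are illustrative.
    for step, loss in [(401, 7.20411), (451, 7.5), (501, 6.9)]:
        exporter.after_eval(step, loss)
    print(exporter.exports)  # ['export/step-401', 'export/step-501']
```

Because only improvements trigger an export, the most recent export is always the best model evaluated so far, while regular checkpoints keep serving their recovery role.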

EDIT: In reply to Nicolas's comment: in addition to the most recent checkpoints, you can keep periodic snapshots by setting the keep_checkpoint_every_n_hours option of the RunConfig that you pass when creating the Estimator. If you then find that your model achieved its best performance 10 hours ago, you should be able to find a snapshot from roughly that time.
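To make "find a snapshot from roughly that time" concrete, here is a small sketch. It assumes you have recorded the validation loss per global step (as in the "Validation (step N)" log lines) and listed the global steps of the snapshots that were kept. With keep_checkpoint_every_n_hours the kept snapshots are spaced by wall-clock time, but each one is still named model.ckpt-<step>, so comparing steps is enough. The function names are hypothetical.

```python
from typing import Dict, Iterable


def best_eval_step(eval_losses: Dict[int, float]) -> int:
    """Global step whose validation loss is lowest."""
    return min(eval_losses, key=lambda step: eval_losses[step])


def nearest_snapshot(eval_losses: Dict[int, float], snapshot_steps: Iterable[int]) -> int:
    """Kept snapshot closest to the best evaluation step (earlier step wins ties)."""
    best = best_eval_step(eval_losses)
    return min(snapshot_steps, key=lambda step: (abs(step - best), step))


if __name__ == "__main__":
    # Illustrative numbers only (the step-401 loss mirrors the question's log).
    losses = {401: 7.20411, 451: 7.5, 501: 6.9, 551: 7.1}
    kept = [201, 451, 701]
    print(best_eval_step(losses))  # 501
    print(nearest_snapshot(losses, kept))  # 451
```

The restored snapshot is only an approximation of the best model; the closer together the snapshots are kept, the smaller that gap.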

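Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.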

 