[英]Save model checkpoint only when model shows improvement in TensorFlow
您是否知道在使用实验中包含的Estimator时是否有办法选择保存哪个模型? 因为每个'save_checkpoints_steps',模型都会被保存,但这个模型不一定是最好的。
def model_fn(features, labels, mode, params):
predict = model_predict_()
loss = model_loss()
train_op = model_train_op(loss, mode)
predictions = {"predictions": predict}
return tf.estimator.EstimatorSpec(
mode = mode,
predictions = predictions,
loss = loss,
train_op = train_op,
)
def experiment_fn(run_config, hparams):
estimator = tf.estimator.Estimator(
model_fn = model_fn,
config = run_config,
params = hparams
)
return learn.Experiment(
estimator = estimator,
train_input_fn = train_input_fn,
eval_input_fn = eval_input_fn,
eval_metrics = None,
train_steps = 1000,
)
ex = learn_runner.run(
experiment_fn = experiment_fn,
run_config = run_config,
schedule = "train_and_evaluate",
hparams = hparams
)
输出如下:
INFO:tensorflow:将401的检查点保存到。\\ model.ckpt中。
INFO:tensorflow:global_step / sec:0.157117 INFO:tensorflow:step = 401,loss = 2.95048(636.468 sec)
信息:tensorflow:在2017-09-05-20:06:07开始评估信息:tensorflow:从。\\ model.ckpt-401恢复参数
信息:tensorflow:评估[1/1] INFO:tensorflow:完成评估2017-09-05-20:06:09
INFO:tensorflow:为全局步骤401保存dict:global_step = 401,loss = 7.20411
INFO:tensorflow:验证(步骤401):global_step = 401,loss = 7.20411
信息:tensorflow:训练损失= 2.95048,步骤= 401(315.393秒)
INFO:tensorflow:将451的检查点保存到。\\ model.ckpt中。
信息:tensorflow:在2017-09-05-20:11:32开始评估
信息:tensorflow:从。\\ model.ckpt-451恢复参数
信息:张量流:评估[1/1]
你会看到每次它保存最后一个模型,这不一定是最好的。
为您的训练过程中断的事件保存检查点。 如果您没有检查点,则需要从头开始重新启动。 对于需要数周训练的大型车型来说,这是一个大问题。
一旦您的培训完成并且您对您的模型感到满意(用您的话说,“这是最好的”),您可以使用https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator明确保存它#export_savedmodel 。 调用此方法位于您用于创建Experiemnt
的Estimator
上。 请注意,此方法保存“推理模型”,这意味着所有渐变操作都将从其中删除而不保存。
编辑:在回复尼古拉斯的评论:您可以保存快照周期性除了使用最近期的keep_checkpoint_every_n_hours
选项RunConfig
你创建估计时通过。 如果您发现您的模型在10小时前达到了最佳性能,那么您应该可以从大致那个时间找到快照。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.