简体   繁体   English

使用预训练模型和配置文件时如何停止基于损失的训练?

[英]How to stop training based on loss when using Pre-trained model and Configuration file?

I am using a Faster RCNN model to train an object detector, using the Pipeline configuration file.我正在使用更快的 RCNN 模型来训练对象检测器,使用管道配置文件。 I know that training can be stopped by simply cancelling directly (ctrl+c).我知道可以通过直接取消(ctrl+c)来停止训练。 I want the training to stop automatically based on Loss value.我希望训练根据损失值自动停止。 How can this be done?如何才能做到这一点? I am aware that keras callbacks can be used when monitoring epochs.我知道在监视纪元时可以使用 keras 回调。 Is there any such option when using configuration files and pre-trained models (which monitors steps).使用配置文件和预训练模型(监控步骤)时是否有任何此类选项。 Thanks.谢谢。

It might just be a hack, but I found a solution to my question.这可能只是一个黑客,但我找到了我的问题的解决方案。 The Object detector requires tf_slim package to be installed.对象检测器需要安装tf_slim包。 And within the tf_slim package, there is a module called learning.py .tf_slim包中,有一个名为learning.py的模块。 The complete path to this might look something like this: /usr/local/lib/python3.6/site-packages/tf_slim/learning.py Here, in the learning.py , starting Line 764, the code looks something like this:完整路径可能如下所示: /usr/local/lib/python3.6/site-packages/tf_slim/learning.py这里,在learning.py ,从第 764 行开始,代码如下所示:

try:
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
# OutOfRangeError is thrown when epoch limit per
# tf.compat.v1.train.limit_epochs is reached.
logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

I wrote a small if statement to check the maximum value for the last five values of the total_loss , and if below a certain threshold (in this case 3), make should_stop True .我写了一个小的if语句来检查total_loss的最后五个值的total_loss ,如果低于某个阈值(在本例中为 3),则使should_stop True This is shown below:这如下所示:

try:
  total_loss_list = []
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    total_loss_list.append(total_loss)
    if len(total_loss_list) > 5:
      if max(total_loss_list[-5:]) < 3:
        should_stop = True
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

If the loss values are continuously below 3 for five steps, then the training stops.如果损失值连续五步低于 3,则训练停止。 The downside to this is that, the package distribution of tf_slim has to be altered.这样做的缺点是,必须更改tf_slim的包分布。 And every time you work on a new object detection problem, this threshold loss value will change.每次你处理一个新的对象检测问题时,这个阈值损失值都会改变。 A better way would be to use a configuration file where you supply the threshold loss value.更好的方法是使用配置文件,您可以在其中提供阈值损失值。 But I'm stopping here for now.但我现在就停在这里。 If anyone else has a better solution, please share.如果其他人有更好的解决方案,请分享。 I hope this helps someone.我希望这可以帮助别人。 Thank you!谢谢!

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM