
How to stop training based on loss when using a pre-trained model and a configuration file?

I am using a Faster R-CNN model to train an object detector with a pipeline configuration file. I know that training can be stopped manually by cancelling it (Ctrl+C), but I want training to stop automatically based on the loss value. How can this be done? I am aware that Keras callbacks can do this when monitoring epochs. Is there a similar option when using configuration files and pre-trained models, where training progress is measured in steps? Thanks.

It might just be a hack, but I found a solution to my question. The object detector requires the tf_slim package to be installed, and within tf_slim there is a module called learning.py. Its full path might look something like /usr/local/lib/python3.6/site-packages/tf_slim/learning.py. In learning.py, starting at line 764, the code looks something like this:

try:
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

I added a small if statement that checks the maximum of the last five total_loss values; if that maximum is below a threshold (3 in this case), it sets should_stop to True:

try:
  total_loss_list = []
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    total_loss_list.append(total_loss)
    # Stop once each of the last five losses is below the threshold (3 here).
    if len(total_loss_list) >= 5:
      if max(total_loss_list[-5:]) < 3:
        should_stop = True
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
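The same stopping rule can be pulled out of the loop into a small, testable helper. This is a minimal sketch, not part of tf_slim: the class name `LossWindowStopper` and its `threshold`/`window` parameters are hypothetical. It uses `collections.deque` with `maxlen` so only the last `window` losses are kept, instead of a list that grows for the whole run.

```python
from collections import deque


class LossWindowStopper:
    """Hypothetical helper (not a tf_slim API): signals a stop once the
    last `window` losses are all below `threshold`."""

    def __init__(self, threshold: float = 3.0, window: int = 5) -> None:
        if window < 1:
            raise ValueError("window must be at least 1")
        self.threshold = threshold
        self.window = window
        # deque(maxlen=...) drops the oldest loss automatically when full.
        self._recent = deque(maxlen=window)

    def update(self, loss: float) -> bool:
        """Record one step's loss; return True if training should stop."""
        self._recent.append(float(loss))
        # Only decide once a full window of losses has been seen.
        if len(self._recent) < self.window:
            return False
        return max(self._recent) < self.threshold


if __name__ == "__main__":
    stopper = LossWindowStopper(threshold=3.0, window=5)
    losses = [5.1, 4.2, 2.9, 2.8, 2.7, 2.6, 2.5]
    for step, loss in enumerate(losses, start=1):
        if stopper.update(loss):
            print(f"stop at step {step}")  # prints "stop at step 7"
            break
```

Inside the patched loop, a `stopper` created before the `while` would be used as `should_stop = should_stop or stopper.update(total_loss)`, which keeps any stop signal that train_step_fn already returned.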

If the loss stays below 3 for five consecutive steps, training stops. The downside is that the installed tf_slim package has to be modified, and the threshold will likely need to change for every new object detection problem. A better approach would be to read the threshold from a configuration file, but I'm stopping here for now. If anyone has a better solution, please share. I hope this helps someone. Thank you!
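A possible way to avoid editing the installed package, assuming your training entry point lets you pass a custom step function: tf_slim's `learning.train` accepts a `train_step_fn` argument, and the loop quoted above only looks at the `(total_loss, should_stop)` pair that function returns. The sketch below wraps any function with that signature and reads the threshold from a small JSON file. The file format, the `loss_threshold`/`window` keys, and the helper names `load_stop_config` and `with_loss_threshold` are all hypothetical. Whether the Object Detection API's own training script forwards `train_step_fn` to tf_slim is something to check in your version; if it does not, that script would be the place to patch rather than tf_slim itself.

```python
import json
from collections import deque
from typing import Any, Callable, Tuple

# Shape of a tf_slim-style step function, as called in the loop above:
# (sess, train_op, global_step, train_step_kwargs) -> (total_loss, should_stop)
StepFn = Callable[[Any, Any, Any, Any], Tuple[float, bool]]


def load_stop_config(path: str) -> Tuple[float, int]:
    """Read a hypothetical JSON file such as {"loss_threshold": 3.0, "window": 5}."""
    with open(path) as f:
        cfg = json.load(f)
    return float(cfg["loss_threshold"]), int(cfg.get("window", 5))


def with_loss_threshold(step_fn: StepFn, threshold: float, window: int = 5) -> StepFn:
    """Wrap step_fn so it also requests a stop once the last `window`
    losses are all below `threshold`. Stops requested by step_fn
    itself (e.g. a step limit) are passed through unchanged."""
    recent = deque(maxlen=window)

    def wrapped(sess, train_op, global_step, train_step_kwargs):
        total_loss, should_stop = step_fn(sess, train_op, global_step, train_step_kwargs)
        recent.append(float(total_loss))
        if len(recent) == window and max(recent) < threshold:
            should_stop = True
        return total_loss, should_stop

    return wrapped
```

With TF-Slim this would look roughly like `slim.learning.train(train_op, logdir, train_step_fn=with_loss_threshold(slim.learning.train_step, *load_stop_config("stop.json")))`, where `stop.json` is a hypothetical per-project file, so changing the threshold no longer means touching library code.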
