简体   繁体   中英

Early stopping in the TensorFlow Object Detection API

I am trying to implement early stopping in the TensorFlow Object Detection (OD) API. I used this code.

Here is my EarlyStoppingHook (it is essentially just a copy of the above code):

class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a stop when a monitored value stops improving.

    Adapted from the Keras ``EarlyStopping`` callback. The monitored graph
    element is fetched on every ``session.run`` call the hook is attached to.
    If its value fails to improve by more than ``min_delta`` for ``patience``
    consecutive runs, ``run_context.request_stop()`` is called.

    Args:
      monitor: Name of the graph element (tensor or operation) to watch. When
        it names an operation, the operation's first output is used.
      min_delta: Minimum change that counts as an improvement. Only its
        magnitude is used; the direction comes from ``mode``.
      patience: Number of consecutive non-improving runs tolerated before a
        stop is requested.
      mode: ``'min'`` (lower is better), ``'max'`` (higher is better) or
        ``'auto'`` (``'max'`` if ``'acc'`` occurs in ``monitor``, otherwise
        ``'min'``). Unknown values fall back to ``'auto'`` with a warning.
    """

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        self.monitor = monitor
        # Keep the name separately: begin() replaces ``self.monitor`` with a
        # tensor, and must be able to re-resolve the name in a fresh graph.
        self._monitor_name = monitor
        self.patience = patience
        self.wait = 0  # consecutive runs without improvement
        self.max_wait = 0  # longest such streak that ended in an improvement
        self.ind = 0  # number of after_run calls seen so far

        if mode not in ['auto', 'min', 'max']:
            # Only ``mode`` may be passed as a format argument: the message has
            # a single %s, so any extra argument breaks log-record formatting.
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        elif 'acc' in monitor:
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less

        # after_run tests ``monitor_op(current - min_delta, best)``. Storing the
        # delta with the sign matching the direction means an improvement has
        # to beat ``best`` by more than |min_delta| in either mode.
        self.min_delta = abs(min_delta)
        if self.monitor_op == np.less:
            self.min_delta *= -1

        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works everywhere.
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def begin(self):
        """Resolve the monitored name to a tensor in the default graph."""
        graph = tf.get_default_graph()
        element = graph.as_graph_element(self._monitor_name)
        if isinstance(element, tf.Operation):
            element = element.outputs[0]
        self.monitor = element

    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Ask the session to also fetch the monitored tensor."""
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        """Update the best value / wait counter and stop when patience runs out."""
        self.ind += 1

        current = run_values.results

        if self.ind % 200 == 0:
            print(f"loss value (inside hook!!! ): {current}, best: {self.best}, wait: {self.wait}, max_wait: {self.max_wait}")

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.max_wait = max(self.max_wait, self.wait)
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()

And I use the class like this:


# EarlyStoppingHook.begin() resolves 'total_loss' with graph.as_graph_element()
# in the graph the hook runs in. Because the hook is attached to the TrainSpec,
# that is presumably the training graph (i.e. the training loss) — confirm
# against the model's graph if in doubt.
early_stopping_hook = EarlyStoppingHook(monitor='total_loss', patience=2000)
train_hooks = [early_stopping_hook]
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=train_steps,
    hooks=train_hooks,
)

What I don't understand is what `total_loss` is: is it the validation loss or the training loss? I also don't understand where these losses ('total_loss', 'loss_1', 'loss_2') are defined.

So, here is what worked for me:

from matplotlib import pyplot as plt
import numpy as np

import collections
import os 

_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'

def _summaries(eval_dir):
  """Yields `tensorflow.Event` protos from event files in the eval dir.

  Event files are visited in sorted filename order rather than raw glob order,
  which is unspecified. This makes the sequence of yielded events deterministic,
  which matters to callers that merge events by step, where later values
  overwrite earlier ones. TensorFlow event file names typically embed a
  creation timestamp after the prefix, so this order is roughly chronological.

  Args:
    eval_dir: Directory containing summary files with eval metrics.
  Yields:
    `tensorflow.Event` object read from the event files. Nothing is yielded
    if `eval_dir` does not exist.
  """
  if not tf.compat.v1.gfile.Exists(eval_dir):
    return
  event_files = sorted(tf.compat.v1.gfile.Glob(
      os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)))
  for event_file in event_files:
    yield from tf.compat.v1.train.summary_iterator(event_file)

def read_eval_metrics(eval_dir):
  """Collect scalar summaries from the event files under `eval_dir`.

  Args:
    eval_dir: Directory containing summary event files with eval metrics.
  Returns:
    A `collections.OrderedDict` in ascending global-step order. It maps each
    step to a `dict` of summary tag -> `simple_value`. Steps whose events carry
    no `simple_value` entries do not appear.
  """
  per_step = collections.defaultdict(dict)
  for event in _summaries(eval_dir):
    if event.HasField('summary'):
      scalars = {
          entry.tag: entry.simple_value
          for entry in event.summary.value
          if entry.HasField('simple_value')
      }
      if scalars:
        per_step[event.step].update(scalars)
  ordered_steps = sorted(per_step.items(), key=lambda pair: pair[0])
  return collections.OrderedDict(ordered_steps)
  
# Build the (step, total loss) series to plot. Steps whose summaries lack the
# tag are skipped so that x and y stay aligned; indexing every step directly
# raises KeyError when a step (e.g. in a training event dir) has no such tag.
# NOTE: a misspelled tag therefore yields empty x / y rather than an error.
_metric_tag = 'Loss/total_loss'
met_dict_2 = read_eval_metrics('/content/gdrive2/My Drive/models/retinanet/eval_0')
_points = [
    (step, metrics[_metric_tag])
    for step, metrics in met_dict_2.items()
    if _metric_tag in metrics
]
x = [step for step, _ in _points]
y = [value for _, value in _points]

The `read_eval_metrics` function returns a dictionary whose keys are iteration (global step) numbers and whose values are dicts of the metrics and losses computed at that evaluation step. You can also use this function on training event files; you just need to change the path.

Example of one key–value pair from the returned dictionary:

(4988, {'DetectionBoxes_Precision/Precision@.50IOU': 0.12053315341472626,
               'DetectionBoxes_Precision/mAP': 0.060865387320518494,
               'DetectionBoxes_Precision/mAP (large)': 0.07213596999645233,
               'DetectionBoxes_Precision/mAP (medium)': 0.062120337039232254,
               'DetectionBoxes_Precision/mAP (small)': 0.02642354555428028,
               'DetectionBoxes_Precision/mAP@.50IOU': 0.11469704657793045,
               'DetectionBoxes_Precision/mAP@.75IOU': 0.06001879647374153,
               'DetectionBoxes_Recall/AR@1': 0.13470394909381866,
               'DetectionBoxes_Recall/AR@10': 0.20102562010288239,
               'DetectionBoxes_Recall/AR@100': 0.2040158212184906,
               'DetectionBoxes_Recall/AR@100 (large)': 0.2639017701148987,
               'DetectionBoxes_Recall/AR@100 (medium)': 0.20173722505569458,
               'DetectionBoxes_Recall/AR@100 (small)': 0.10018187761306763,
               'Loss/classification_loss': 1.0127471685409546,
               'Loss/localization_loss': 0.3542810380458832,
               'Loss/regularization_loss': 0.708609938621521,
               'Loss/total_loss': 2.0756208896636963,
               'learning_rate': 0.0006235376931726933,
               'loss': 2.0756208896636963})

So I ended up setting the `monitor` argument to 'Loss/total_loss' instead of 'total_loss' in EarlyStoppingHook.

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM