Early stopping in the TensorFlow Object Detection API

I am trying to implement early stopping in the TensorFlow Object Detection API (TF OD API). I used this code.

Here is my EarlyStoppingHook (it is essentially just a copy of the above code):

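import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a stop when the monitored tensor stops improving."""

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """
        Args:
          monitor: name of the graph tensor (or op) to watch.
          min_delta: minimum change that counts as an improvement.
          patience: number of consecutive run calls without improvement
            after which a stop is requested.
          mode: 'min', 'max' or 'auto'; 'auto' picks 'max' when the monitor
            name contains 'acc' and 'min' otherwise.
        """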
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.max_wait = 0
        self.ind = 0
        if mode not in ['auto', 'min', 'max']:
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        # np.inf is the spelling that works across NumPy versions.
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def begin(self):
        # Convert names to tensors if given
        graph = tf.compat.v1.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        self.ind += 1

        # Value of the monitored tensor fetched by this run call.
        current = run_values.results

        # Periodic debug output, every 200 calls.
        if self.ind % 200 == 0:
            print(f"loss value (inside hook): {current}, best: {self.best}, "
                  f"wait: {self.wait}, max_wait: {self.max_wait}")

        if self.monitor_op(current - self.min_delta, self.best):
            # Improvement: record it and reset the patience counter.
            self.best = current
            if self.max_wait < self.wait:
                self.max_wait = self.wait
            self.wait = 0
        else:
            # No improvement: stop after `patience` consecutive calls.
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()
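
To make the interaction between patience and min_delta easier to follow, here is a minimal sketch that replays the same counter logic on a plain Python list, outside TensorFlow. The function name `first_stop_index` and the sample loss values are made up for illustration, and only the 'min' and 'max' modes are handled.

```python
import operator


def first_stop_index(values, patience, min_delta=0.0, mode='min'):
    """Return the index at which the hook's rule would request a stop, or None."""
    monitor_op = operator.lt if mode == 'min' else operator.gt
    # Same sign convention as the hook: the delta is negated when lower is better.
    delta = min_delta if mode == 'max' else -min_delta
    best = float('inf') if mode == 'min' else float('-inf')
    wait = 0
    for i, current in enumerate(values):
        if monitor_op(current - delta, best):
            # Improvement: remember it and reset the counter.
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


# Hypothetical loss curve: improves, then plateaus, then improves again.
losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.82, 0.5]
print(first_stop_index(losses, patience=3))
print(first_stop_index(losses, patience=3, min_delta=0.05))
```

With these sample values the late drop to 0.5 is never reached, because the stop is requested during the plateau. A larger min_delta makes the rule stricter, so the stop comes earlier.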

And I use the class like this:

early_stopping_hook = EarlyStoppingHook(
    monitor='total_loss',
    patience=2000)

train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn, max_steps=train_steps, hooks=[early_stopping_hook])

What I don't understand is what total_loss is. Is this the validation loss or the training loss? I also don't understand where these losses ('total_loss', 'loss_1', 'loss_2') are defined.

So, here is what worked for me:

import collections
import os

from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf

_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'

def _summaries(eval_dir):
  """Yields `tensorflow.Event` protos from event files in the eval dir.
  Args:
    eval_dir: Directory containing summary files with eval metrics.
  Yields:
    `tensorflow.Event` object read from the event files.
  """
  if tf.compat.v1.gfile.Exists(eval_dir):
    for event_file in tf.compat.v1.gfile.Glob(
        os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)):
      for event in tf.compat.v1.train.summary_iterator(event_file):
        yield event

def read_eval_metrics(eval_dir):
  """Helper to read eval metrics from eval summary files.
  Args:
    eval_dir: Directory containing summary files with eval metrics.
  Returns:
    A `dict` with global steps mapping to `dict` of metric names and values.
  """
  eval_metrics_dict = collections.defaultdict(dict)
  for event in _summaries(eval_dir):
    if not event.HasField('summary'):
      continue
    metrics = {}
    for value in event.summary.value:
      if value.HasField('simple_value'):
        metrics[value.tag] = value.simple_value
    if metrics:
      eval_metrics_dict[event.step].update(metrics)
  return collections.OrderedDict(
      sorted(eval_metrics_dict.items(), key=lambda t: t[0]))
  
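met_dict_2 = read_eval_metrics('/content/gdrive2/My Drive/models/retinanet/eval_0')

# Collect the evaluation steps and the total loss recorded at each of them.
x = []
y = []
for k, v in met_dict_2.items():
    x.append(k)
    y.append(v['Loss/total_loss'])

# Plot the evaluation loss against the global step.
plt.plot(x, y)
plt.xlabel('step')
plt.ylabel('Loss/total_loss')
plt.show()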

The read_eval_metrics function returns a dictionary whose keys are iteration numbers and whose values are the metrics and losses computed at that evaluation step. You can also use this function for train event files; you just need to change the path.
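
As a small illustration of working with that dictionary, the sketch below picks the step with the best value of a given tag. The helper name `best_step` and the numbers in `sample` are hypothetical; `sample` only mimics the shape of what read_eval_metrics returns.

```python
import collections


def best_step(metrics_by_step, tag, lower_is_better=True):
    """Return (step, value) of the best recorded value for `tag`, or None."""
    # Skip steps where this tag was not written.
    candidates = [(step, metrics[tag])
                  for step, metrics in metrics_by_step.items()
                  if tag in metrics]
    if not candidates:
        return None
    pick = min if lower_is_better else max
    return pick(candidates, key=lambda item: item[1])


# Made-up values in the same shape as the read_eval_metrics output.
sample = collections.OrderedDict([
    (1000, {'Loss/total_loss': 2.9, 'DetectionBoxes_Precision/mAP': 0.02}),
    (2000, {'Loss/total_loss': 2.3, 'DetectionBoxes_Precision/mAP': 0.05}),
    (3000, {'Loss/total_loss': 2.5, 'DetectionBoxes_Precision/mAP': 0.06}),
])

print(best_step(sample, 'Loss/total_loss'))
print(best_step(sample, 'DetectionBoxes_Precision/mAP', lower_is_better=False))
```

For losses the lowest value is the best one, while for metrics such as mAP the highest is, hence the `lower_is_better` switch.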

Example of one key-value pair from the returned dictionary:

(4988, {'DetectionBoxes_Precision/Precision@.50IOU': 0.12053315341472626,
               'DetectionBoxes_Precision/mAP': 0.060865387320518494,
               'DetectionBoxes_Precision/mAP (large)': 0.07213596999645233,
               'DetectionBoxes_Precision/mAP (medium)': 0.062120337039232254,
               'DetectionBoxes_Precision/mAP (small)': 0.02642354555428028,
               'DetectionBoxes_Precision/mAP@.50IOU': 0.11469704657793045,
               'DetectionBoxes_Precision/mAP@.75IOU': 0.06001879647374153,
               'DetectionBoxes_Recall/AR@1': 0.13470394909381866,
               'DetectionBoxes_Recall/AR@10': 0.20102562010288239,
               'DetectionBoxes_Recall/AR@100': 0.2040158212184906,
               'DetectionBoxes_Recall/AR@100 (large)': 0.2639017701148987,
               'DetectionBoxes_Recall/AR@100 (medium)': 0.20173722505569458,
               'DetectionBoxes_Recall/AR@100 (small)': 0.10018187761306763,
               'Loss/classification_loss': 1.0127471685409546,
               'Loss/localization_loss': 0.3542810380458832,
               'Loss/regularization_loss': 0.708609938621521,
               'Loss/total_loss': 2.0756208896636963,
               'learning_rate': 0.0006235376931726933,
               'loss': 2.0756208896636963})

So I ended up setting the monitor argument to 'Loss/total_loss' instead of 'total_loss' in EarlyStoppingHook.
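
To find out which names are available in the first place, the tags can be listed from the same dictionary. This is only a sketch: `available_tags` is a hypothetical helper and `sample` is made-up data in the shape returned by read_eval_metrics. Note that these are summary tags from the event files; whether a graph tensor with the same name exists is something to check in your own graph.

```python
def available_tags(metrics_by_step):
    """Return the sorted list of metric tags seen across all steps."""
    tags = set()
    for metrics in metrics_by_step.values():
        tags.update(metrics)  # adds the dictionary keys, i.e. the tag names
    return sorted(tags)


# Made-up data in the same shape as the read_eval_metrics output.
sample = {
    4000: {'Loss/total_loss': 2.3, 'loss': 2.3},
    4988: {'Loss/total_loss': 2.1, 'learning_rate': 0.0006, 'loss': 2.1},
}

print(available_tags(sample))
```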
