Early stopping in TensorFlow Object Detection API
I am trying to implement early stopping in the TF OD API. I used this code.
Here is my EarlyStoppingHook (it is essentially just a copy of the above code):
# Imports are not shown in the original post; these are the ones the code appears to need.
import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests a stop when the monitored tensor stops improving."""

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """
        Args:
          monitor: Name of the tensor or op to monitor.
          min_delta: Minimum change that counts as an improvement.
          patience: Number of run calls without improvement before stopping.
          mode: One of 'auto', 'min' or 'max'.
        """
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.max_wait = 0
        self.ind = 0

        if mode not in ['auto', 'min', 'max']:
            # One %s placeholder, one argument (the extra RuntimeWarning
            # argument broke the log message formatting).
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # 'auto': accuracy-like metrics are maximized, everything else minimized.
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        # np.inf rather than np.Inf, which was removed in NumPy 2.0.
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def begin(self):
        # Convert names to tensors if given
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        # Ask the session to also fetch the monitored tensor on every run call.
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        self.ind += 1
        current = run_values.results

        if self.ind % 200 == 0:
            print(f"loss value (inside hook!!! ): {current}, best: {self.best}, wait: {self.wait}, max_wait: {self.max_wait}")

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            if self.max_wait < self.wait:
                self.max_wait = self.wait
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()
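To make the patience logic easier to reason about, here is a minimal pure-Python sketch of the same counting scheme in 'min' mode, applied to a plain list of loss values instead of a live session. The function name `steps_until_stop` and the sample losses are made up for illustration and are not part of the hook or of TensorFlow. Note that the hook's `after_run` is called once per session run call, so `patience` counts run calls (training steps), not epochs or evaluations.

```python
def steps_until_stop(losses, patience, min_delta=0.0):
    """Return the 1-based step at which a stop would be requested, or None.

    Hypothetical helper that mirrors the hook's 'min' mode: a value counts as
    an improvement only if it beats the best value so far by more than
    min_delta.
    """
    best = float("inf")
    wait = 0
    for step, current in enumerate(losses, start=1):
        if current + min_delta < best:
            # Improvement: remember it and reset the counter.
            best = current
            wait = 0
        else:
            # No improvement: count it and stop once patience is used up.
            wait += 1
            if wait >= patience:
                return step
    return None


# Illustrative values only.
sample_losses = [1.0, 0.9, 0.95, 0.92, 0.91, 0.8]
print(steps_until_stop(sample_losses, patience=3))
print(steps_until_stop(sample_losses, patience=4))
```

With a per-step training loss, which is usually noisy, a small `patience` can trigger a stop very early, which is one reason to check which loss the monitored tensor actually represents.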
And I use the class like this:
early_stopping_hook = EarlyStoppingHook(
    monitor='total_loss',
    patience=2000)
train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn, max_steps=train_steps, hooks=[early_stopping_hook])
What I don't understand is what total_loss is. Is it the validation loss or the training loss? I also don't understand where these losses ('total_loss', 'loss_1', 'loss_2') are defined.
So, here is what worked for me:
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf  # missing in the original snippet, needed for tf.compat.v1
import collections
import os

_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'


def _summaries(eval_dir):
    """Yields `tensorflow.Event` protos from event files in the eval dir.

    Args:
      eval_dir: Directory containing summary files with eval metrics.

    Yields:
      `tensorflow.Event` object read from the event files.
    """
    if tf.compat.v1.gfile.Exists(eval_dir):
        for event_file in tf.compat.v1.gfile.Glob(
                os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)):
            for event in tf.compat.v1.train.summary_iterator(event_file):
                yield event


def read_eval_metrics(eval_dir):
    """Helper to read eval metrics from eval summary files.

    Args:
      eval_dir: Directory containing summary files with eval metrics.

    Returns:
      A `dict` with global steps mapping to `dict` of metric names and values.
    """
    eval_metrics_dict = collections.defaultdict(dict)
    for event in _summaries(eval_dir):
        if not event.HasField('summary'):
            continue
        metrics = {}
        for value in event.summary.value:
            if value.HasField('simple_value'):
                metrics[value.tag] = value.simple_value
        if metrics:
            eval_metrics_dict[event.step].update(metrics)
    return collections.OrderedDict(
        sorted(eval_metrics_dict.items(), key=lambda t: t[0]))


met_dict_2 = read_eval_metrics('/content/gdrive2/My Drive/models/retinanet/eval_0')

# Collect the global steps (x) and the total loss at each evaluation (y).
x = []
y = []
for k, v in met_dict_2.items():
    x.append(k)
    y.append(v['Loss/total_loss'])
The read_eval_metrics function returns a dictionary whose keys are iteration numbers (global steps) and whose values are the metrics and losses computed at that evaluation step. You can also use this function for train event files; you just need to change the path.
Example of one key-value pair from the returned dictionary:
(4988, {'DetectionBoxes_Precision/Precision@.50IOU': 0.12053315341472626,
'DetectionBoxes_Precision/mAP': 0.060865387320518494,
'DetectionBoxes_Precision/mAP (large)': 0.07213596999645233,
'DetectionBoxes_Precision/mAP (medium)': 0.062120337039232254,
'DetectionBoxes_Precision/mAP (small)': 0.02642354555428028,
'DetectionBoxes_Precision/mAP@.50IOU': 0.11469704657793045,
'DetectionBoxes_Precision/mAP@.75IOU': 0.06001879647374153,
'DetectionBoxes_Recall/AR@1': 0.13470394909381866,
'DetectionBoxes_Recall/AR@10': 0.20102562010288239,
'DetectionBoxes_Recall/AR@100': 0.2040158212184906,
'DetectionBoxes_Recall/AR@100 (large)': 0.2639017701148987,
'DetectionBoxes_Recall/AR@100 (medium)': 0.20173722505569458,
'DetectionBoxes_Recall/AR@100 (small)': 0.10018187761306763,
'Loss/classification_loss': 1.0127471685409546,
'Loss/localization_loss': 0.3542810380458832,
'Loss/regularization_loss': 0.708609938621521,
'Loss/total_loss': 2.0756208896636963,
'learning_rate': 0.0006235376931726933,
'loss': 2.0756208896636963})
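Given a dictionary of this shape, one thing you might do is look up the evaluation step with the lowest total loss. The following is only a sketch: `best_eval_step` is a hypothetical helper, and `toy_metrics` holds made-up values standing in for the output of read_eval_metrics. Steps that do not contain the requested tag are skipped rather than raising a KeyError.

```python
import collections


def best_eval_step(metrics_by_step, tag="Loss/total_loss"):
    """Return (step, value) for the lowest value of `tag`, or None if absent.

    Hypothetical helper; `metrics_by_step` maps a global step to a dict of
    metric names and values.
    """
    candidates = {
        step: metrics[tag]
        for step, metrics in metrics_by_step.items()
        if tag in metrics
    }
    if not candidates:
        return None
    step = min(candidates, key=candidates.get)
    return step, candidates[step]


# Made-up values for illustration; real ones come from the event files.
toy_metrics = collections.OrderedDict([
    (1000, {"Loss/total_loss": 2.4}),
    (2000, {"Loss/total_loss": 2.1}),
    (3000, {"Loss/total_loss": 2.3}),
    (4000, {"learning_rate": 0.0006}),  # no loss logged at this step
])

print(best_eval_step(toy_metrics))
```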
So I ended up setting the monitor argument of EarlyStoppingHook to 'Loss/total_loss' instead of 'total_loss'.