Tensorflow 2 ModelCheckpoint callback with multiclass recall custom metric

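I am building a CNN classifier for a multiclass classification task (num_classes=7). Because of the class imbalance and the subject domain, my target metric for this task is the macro-averaged recall across classes.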
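For reference, macro-averaged recall is the unweighted mean of the per-class recalls, so each of the C = 7 classes counts equally however many samples it has ($TP_c$ and $FN_c$ are the true positives and false negatives of class $c$):

```latex
\text{macro recall} = \frac{1}{C}\sum_{c=1}^{C}\frac{TP_c}{TP_c + FN_c}, \qquad C = 7
```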

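While the model trains, I want to checkpoint it: at the end of each epoch, save the model if the validation multiclass macro recall is higher than the best value seen in any previous epoch. I believe this can be done in two stages: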

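  1. Create a custom metric that computes the average recall across classes (the multiclass case) on the validation data at the end of each epoch.
  2. Create a ModelCheckpoint callback that tracks the custom metric and saves the model whenever it exceeds the previous maximum.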

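Does anyone have an example of this, or something similar? I am more interested in the implementation of the custom metric for macro-averaged multiclass recall, since I believe the callback is easy once the metric is defined in model.compile().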

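I implemented the custom metric by adapting this post with a few tweaks, such as computing a running average. Here is the code for the custom metric: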

import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric


class MacroAverageRecall(Metric):
    """Custom metric for calculating multiclass (macro-averaged) recall during training."""

    def __init__(self,
                 num_classes,
                 batch_size=None,
                 name='multiclass_recall',
                 **kwargs):
        super(MacroAverageRecall, self).__init__(name=name, **kwargs)
        self.batch_size = batch_size  # not used in the computation
        self.num_classes = num_classes
        self.average_recall = self.add_weight(name="recall", initializer="zeros")
        # The batch counter must be a TF variable, not a Python int: during
        # model.fit() update_state runs inside a tf.function, so a Python
        # "+= 1" would only execute while the graph is being traced.
        self.num_batches = self.add_weight(name="num_batches", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored
        recall = 0.0
        pred = K.argmax(y_pred, axis=-1)
        true = K.argmax(y_true, axis=-1)

        for i in range(self.num_classes):
            # 1.0 where the prediction (resp. the label) equals class i, else 0.0
            predicted_instances = K.cast(K.equal(pred, i), 'float32')
            true_instances = K.cast(K.equal(true, i), 'float32')
            # True positives: samples of class i that were predicted as class i
            pred_true_pos = K.sum(true_instances * predicted_instances)
            # Divide by the number of samples of class i in this batch;
            # a class absent from the batch contributes a recall of 0
            all_true_positives = K.sum(true_instances) + K.epsilon()
            recall += pred_true_pos / all_true_positives

        avg_recall = recall / self.num_classes
        # Incremental mean of the per-batch macro recall over the epoch
        self.num_batches.assign_add(1.0)
        recall_update = (avg_recall - self.average_recall) / self.num_batches
        self.average_recall.assign_add(recall_update)

    def result(self):
        return self.average_recall

    def reset_state(self):
        # Keras resets metrics at the start of each epoch and of each evaluation.
        # TF >= 2.5 calls reset_state(); older releases call reset_states(),
        # whose default implementation also sets every weight to zero.
        self.average_recall.assign(0.0)
        self.num_batches.assign(0.0)
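
As a sanity check on the running-average idea, this pure-Python sketch reproduces the update used in update_state and compares the mean of per-batch macro recalls with the macro recall of all samples pooled. The two batches (integer class ids, 3 classes) are made up for illustration, and the helper names are hypothetical.

```python
def per_class_counts(y_true, y_pred, num_classes):
    """Per-class true positives and support (number of true samples)."""
    tp = [0] * num_classes
    support = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        support[t] += 1
        if t == p:
            tp[t] += 1
    return tp, support


def macro_recall(tp, support):
    # Mirrors MacroAverageRecall: a class with no samples contributes 0
    # and the sum is divided by the total number of classes.
    return sum(t / s if s else 0.0 for t, s in zip(tp, support)) / len(tp)


def running_mean(values):
    # Same update as in update_state: m_n = m_{n-1} + (x_n - m_{n-1}) / n
    mean = 0.0
    for n, x in enumerate(values, start=1):
        mean += (x - mean) / n
    return mean


# Two hypothetical validation batches: (labels, predictions)
batches = [([0, 0, 1, 2], [0, 1, 1, 1]), ([2], [2])]
batch_recalls = [macro_recall(*per_class_counts(t, p, 3)) for t, p in batches]

# Macro recall computed once over every sample of the "epoch"
all_true = [c for t, _ in batches for c in t]
all_pred = [c for _, p in batches for c in p]
epoch_recall = macro_recall(*per_class_counts(all_true, all_pred, 3))
```

Here the running mean gives 5/12 (about 0.417) while the pooled value is 2/3 (about 0.667), mainly because classes 0 and 1 are missing from the second batch and each count as 0 there. The incremental update itself is an exact arithmetic mean; the gap comes from averaging per-batch recalls instead of pooling the counts.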

And the checkpoint callback used during model training:

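import os

from tensorflow.keras import callbacks

# Fragment of the author's training code: self._metadata and self._val are
# attributes of the author's own class.
callbacks.ModelCheckpoint(
    filepath=os.path.join(
        self._metadata['checkpoint_directory'],
        f'checkpoint-{self._metadata["create_time"]}.h5'),
    save_best_only=bool(self._val),
    # Keras logs the validation value of 'multiclass_recall' as 'val_multiclass_recall'
    monitor='val_multiclass_recall',
    mode='max',
    verbose=1)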
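
One way to get the macro recall of the whole validation set, rather than an average of per-batch values, is to accumulate per-class true positives and per-class support in metric weights and compute the recall only in result(). This is a sketch, assuming one-hot y_true and softmax y_pred as in the code above; the class name EpochMacroRecall and the helper make_checkpoint are hypothetical, and classes with no samples are left out of the mean (a design choice that makes no difference when all 7 classes appear in the validation set).

```python
import tensorflow as tf


class EpochMacroRecall(tf.keras.metrics.Metric):
    """Macro-averaged recall over all samples seen since the last reset."""

    def __init__(self, num_classes, name="multiclass_recall", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Running per-class counts, summed over every batch of the epoch
        self.true_positives = self.add_weight(
            name="tp", shape=(num_classes,), initializer="zeros")
        self.support = self.add_weight(
            name="support", shape=(num_classes,), initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch
        true = tf.argmax(y_true, axis=-1)
        pred = tf.argmax(y_pred, axis=-1)
        true_1h = tf.one_hot(true, self.num_classes)             # (batch, C)
        hit = tf.cast(tf.equal(true, pred), tf.float32)           # (batch,)
        self.true_positives.assign_add(
            tf.reduce_sum(true_1h * hit[:, None], axis=0))
        self.support.assign_add(tf.reduce_sum(true_1h, axis=0))

    def result(self):
        # Per-class recall; classes with zero support give 0 and are not counted
        recall = tf.math.divide_no_nan(self.true_positives, self.support)
        num_present = tf.reduce_sum(tf.cast(self.support > 0, tf.float32))
        return tf.math.divide_no_nan(tf.reduce_sum(recall), num_present)

    def reset_state(self):
        for v in self.variables:
            v.assign(tf.zeros_like(v))


def make_checkpoint(filepath):
    # Validation metrics are logged with a "val_" prefix, so the metric named
    # "multiclass_recall" is monitored as "val_multiclass_recall".
    return tf.keras.callbacks.ModelCheckpoint(
        filepath=filepath,
        monitor="val_multiclass_recall",
        mode="max",
        save_best_only=True,
        verbose=1,
    )
```

To use it, pass `metrics=[EpochMacroRecall(num_classes=7)]` to `model.compile()`, pass `validation_data` to `model.fit()`, and add `make_checkpoint(...)` to its `callbacks`. Keras resets the metric before each evaluation, so the logged `val_multiclass_recall` covers all validation batches of that epoch.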
