[英]Tensorflow 2.0: Custom metric (balanced accuracy score) for modelcheckpoint not working
[英]Tensorflow 2 ModelCheckpoint callback with multiclass recall custom metric
我正在为多类分类任务(num_classes=7)构建一个 CNN 分类器。 由于不平衡和主题领域,我对这项任务的目标指标是跨类的宏观平均召回率。
当 model 训练时,如果验证多类宏召回被评估为高于以前在整个 epoch 中看到的最高值,我想通过在每个 epoch 结束时保存 model 来检查它。 我相信这将分两个阶段进行:
有人会有这样或类似的例子吗? 我对宏平均多类召回的自定义指标的实现更感兴趣,因为我相信一旦在 model.compile() 中定义了这个指标,回调就可以轻松完成
我通过使用这篇文章并进行了一些调整来实现自定义指标,例如计算运行平均值。 以下是自定义指标的代码:
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric
class MacroAverageRecall( Metric ):
"""Custom metric for calculating multiclass recall during
training"""
def __init__(self,
num_classes,
batch_size,
name='multiclass_recall',
**kwargs):
super( MacroAverageRecall, self ).__init__( name=name, **kwargs )
self.batch_size = batch_size
self.num_classes = num_classes
self.num_batches = 0
self.average_recall = self.add_weight( name="recall", initializer="zeros" )
def update_state(self, y_true, y_pred, sample_weight=None):
recall = 0
pred = K.argmax( y_pred, axis=-1 )
true = K.argmax( y_true, axis=-1 )
for i in range( self.num_classes ):
# Find where the pred equals the class
predicted_instances_bool = K.equal(
pred,
i
)
# Find where the labels equals the class
true_instances_bool = K.equal(
true,
i
)
# Converting tensors of bools to int (1,0)
predicted_instances = K.cast(
predicted_instances_bool,
'float32'
)
true_instances = K.cast(
true_instances_bool,
'float32'
)
# Reshaping tensors
true_reshaped = K.reshape(
true_instances,
(1, -1)
)
predicted_reshaped = K.reshape(
predicted_instances,
(-1, 1)
)
# Find true positives
true_positives = K.dot(
true_reshaped,
predicted_reshaped
)
# Compute the true positive
pred_true_pos = K.sum(
true_positives
)
# divide by all positives in t
all_true_positives = (K.sum( true_instances ) + K.epsilon())
class_recall = pred_true_pos / all_true_positives
recall += class_recall
self.num_batches += 1
avg_recall = recall / self.num_classes
recall_update = (avg_recall - self.average_recall) / self.num_batches
self.average_recall.assign_add( recall_update )
def result(self):
return self.average_recall
def reset_states(self):
# The state of the metric will be reset at the start of each epoch.
self.average_recall.assign( 0. )
以及 model 训练期间使用的检查点:
callbacks.ModelCheckpoint(
filepath=os.path.join(
self._metadata['checkpoint_directory'],
f'checkpoint-{self._metadata["create_time"]}.h5' ),
save_best_only=True if self._val else False,
monitor='val_multiclass_recall',
mode='max',
verbose=1 )
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.