[英]Tensorflow metrics with custom Estimator
我有一个卷积神经网络,我最近重构使用Tensorflow的Estimator API,很大程度上遵循本教程 。 但是,在训练期间,我添加到EstimatorSpec的度量标准没有显示在Tensorboard上,并且似乎没有在tfdbg中进行评估,尽管名称范围和度量标准存在于写入Tensorboard的图表中。
model_fn
的相关位如下:
...
predictions = tf.placeholder(tf.float32, [num_classes], name="predictions")
...
with tf.name_scope("metrics"):
predictions_rounded = tf.round(predictions)
accuracy = tf.metrics.accuracy(input_y, predictions_rounded, name='accuracy')
precision = tf.metrics.precision(input_y, predictions_rounded, name='precision')
recall = tf.metrics.recall(input_y, predictions_rounded, name='recall')
if mode == tf.estimator.ModeKeys.PREDICT:
spec = tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions)
elif mode == tf.estimator.ModeKeys.TRAIN:
...
# if we're doing softmax vs sigmoid, we have different metrics
if cross_entropy == CrossEntropyType.SOFTMAX:
metrics = {
'accuracy': accuracy,
'precision': precision,
'recall': recall
}
elif cross_entropy == CrossEntropyType.SIGMOID:
metrics = {
'precision': precision,
'recall': recall
}
else:
raise NotImplementedError("Unrecognized cross entropy function: {}\t Available types are: SOFTMAX, SIGMOID".format(cross_entropy))
spec = tf.estimator.EstimatorSpec(mode=mode,
loss=loss,
train_op=train_op,
eval_metric_ops=metrics)
else:
raise NotImplementedError('ModeKey provided is not supported: {}'.format(mode))
return spec
任何人都有任何想法为什么这些没有写? 我正在使用Tensorflow 1.7和Python 3.5。 我试过通过tf.summary.scalar
明确地添加它们,虽然它们确实以这种方式进入Tensorboard,但是在第一次通过图形之后它们永远不会更新。
metrics API有一个扭曲,让我们以tf.metrics.accuracy
为例(所有tf.metrics.*
工作相同)。 这将返回2个值, accuracy
指标和upate_op
,这看起来像是您的第一个错误。 你应该有这样的东西:
accuracy, update_op = tf.metrics.accuracy(input_y, predictions_rounded, name='accuracy')
accuracy
只是您期望计算的值,但请注意,您可能希望在多次调用sess.run
时计算准确性,例如,当您计算不完全适合的大型测试集的准确性时记忆。 这就是update_op
用武之地,它会产生结果,因此当你要求accuracy
它会给你一个运行记录。
update_op
没有依赖项,因此您需要在sess.run
显式运行它或添加依赖项。 例如,您可以将其设置为依赖于成本函数,以便在计算成本函数时计算update_op
(导致运行计数以更新准确性):
with tf.control_dependencies(cost):
tf.group(update_op, other_update_ops, ...)
您可以使用局部变量初始值设定项重置度量标准的值:
sess.run(tf.local_variables_initializer())
您需要使用tf.summary.scalar(accuracy)
向tensorboard添加精度,如您所提到的那样(尽管看起来您添加了错误的东西)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.