简体   繁体   English

在Keras使用tf.metrics?

[英]Use tf.metrics in Keras?

I'm especially interested in specificity_at_sensitivity . 我对specificity_at_sensitivity特别感兴趣。 Looking through the Keras docs : 浏览Keras文档

from keras import metrics

model.compile(loss='mean_squared_error',
              optimizer='sgd',
              metrics=[metrics.mae, metrics.categorical_accuracy])

But it looks like the metrics list must have functions of arity 2, accepting (y_true, y_pred) and returning a single tensor value. 但看起来metrics列表必须具有arity 2的函数,接受(y_true, y_pred)并返回单个张量值。


EDIT: Currently here is how I do things: 编辑:目前这是我做的事情:

from sklearn.metrics import confusion_matrix

predictions = model.predict(x_test)
y_test = np.argmax(y_test, axis=-1)
predictions = np.argmax(predictions, axis=-1)
c = confusion_matrix(y_test, predictions)
print('Confusion matrix:\n', c)
print('sensitivity', c[0, 0] / (c[0, 1] + c[0, 0]))
print('specificity', c[1, 1] / (c[1, 1] + c[1, 0]))

The disadvantage of this approach, is I only get the output I care about when training has finished. 这种方法的缺点是,我只能在训练结束时得到我关心的输出。 Would prefer to get metrics every 10 epochs or so. 宁愿每10个纪元左右获得指标。

I've found a related issue on github , and it seems that tf.metrics are still not supported by Keras models. 我在github上发现了一个相关的问题,似乎tf.metrics模型仍然不支持tf.metrics However, in case you are very interested in using tf.metrics.specificity_at_sensitivity , I would suggest the following workaround (inspired by BogdanRuzh's solution): 但是,如果您对使用tf.metrics.specificity_at_sensitivity非常感兴趣,我建议采用以下解决方法(受BogdanRuzh解决方案的启发):

def specificity_at_sensitivity(sensitivity, **kwargs):
    def metric(labels, predictions):
        # any tensorflow metric
        value, update_op = tf.metrics.specificity_at_sensitivity(labels, predictions, sensitivity, **kwargs)

        # find all variables created for this metric
        metric_vars = [i for i in tf.local_variables() if 'specificity_at_sensitivity' in i.name.split('/')[2]]

        # Add metric variables to GLOBAL_VARIABLES collection.
        # They will be initialized for new session.
        for v in metric_vars:
            tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)

        # force to update metric values
        with tf.control_dependencies([update_op]):
            value = tf.identity(value)
            return value
    return metric


model.compile(loss='mean_squared_error',
              optimizer='sgd',
              metrics=[metrics.mae,
                       metrics.categorical_accuracy,
                       specificity_at_sensitivity(0.5)])

UPDATE: 更新:

You can use model.evaluate to retrieve the metrics after training. 您可以使用model.evaluate在培训后检索指标。

I don't think there is a strict limit to only two incoming arguments, in metrics.py the function is just three incoming arguments, but k selects the default value of 5. 我不认为只有两个传入的参数有严格的限制,在metrics.py中,函数只是三个传入参数,但k选择默认值5。

def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
    return K.mean(K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 tf.metrics的返回值是什么意思? - What do the return values of tf.metrics mean? Tensorflow:如何在多类分类中使用 tf.keras.metrics? - Tensorflow: How to use tf.keras.metrics in multiclass classification? 带有 TF 后端的 Keras 指标与 tensorflow 指标 - Keras metrics with TF backend vs tensorflow metrics 在 tf.keras.metrics 中使用不同的指标进行多分类模型 - Use different metrics in tf.keras.metrics for mutli-classification model 使用度量“tf.keras.metrics.AUC”进行训练后,如何在 Keras 中使用“load_model”? - How to use "load_model" in Keras after trained using as metric "tf.keras.metrics.AUC"? 自定义 tf.keras 指标意外结果 - Custom tf.keras metrics unexpected result 使用在 tf.keras 中实现的自定义指标加载 keras model - Loading keras model with custom metrics implemented in tf.keras 何时在分类器神经网络 model 中使用“准确度”字符串或 tf.keras.metrics.Accuracy() - When to use "accuracy" string or tf.keras.metrics.Accuracy() in a classifier neural network model 可以传递给 tf.keras.model.compile 的指标列表 - List of metrics that can be passed to tf.keras.model.compile 了解 tf.keras.metrics.Precision and Recall 进行多类分类 - Understanding tf.keras.metrics.Precision and Recall for multiclass classification
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM