

Custom TensorFlow metric: true positive rate at given false positive rate

I have a binary classification problem with the categories background (bg) = 0 and signal (sig) = 1, for which I am training neural networks (NNs). For monitoring purposes, I am trying to implement a custom metric in Keras with the TensorFlow backend that does the following:

1) Calculate the threshold on my NN output that would result in a false positive rate (classifying bg as signal) of X (in this case X = 0.02, but it could be anything).

2) Calculate the true positive rate at this threshold.
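The two steps can be written down explicitly. In the notation below (introduced here for illustration, not taken from the question), $\hat y_i$ is the NN output of event $i$, $B$ and $S$ are the sets of background and signal events, and an event counts as signal-like when $\hat y_i > t$ (a strict cut is an assumed convention):

$$\mathrm{FPR}(t) = \frac{|\{i \in B : \hat y_i > t\}|}{|B|}, \qquad \mathrm{TPR}(t) = \frac{|\{i \in S : \hat y_i > t\}|}{|S|}$$

Sorting the background outputs in descending order, $\hat y^{B}_{(1)} \ge \hat y^{B}_{(2)} \ge \dots \ge \hat y^{B}_{(|B|)}$, and taking

$$t_X = \hat y^{B}_{(k+1)}, \qquad k = \mathrm{round}(X\,|B|),$$

leaves exactly $k$ background events above the cut when the outputs are distinct, so $\mathrm{FPR}(t_X) = k/|B| \approx X$, and the monitored quantity is $\mathrm{TPR}(t_X)$. For example, with $|B| = 1000$ and $X = 0.02$, $k = 20$ and the cut sits at the 21st-highest background output.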

Given numpy arrays y_true and y_pred, I would write a function like:

import numpy as np

def eff_at_2percent_metric(y_true, y_pred):
    # Flatten (n, 1) Keras-style arrays to 1-D so sorting and masking work
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)

    # NN outputs of background events, highest first
    ordered_bg_predictions = np.sort(y_pred[y_true < 0.5])[::-1]
    # Threshold with a 2% false positive rate: about 2% of bg events lie above it
    threshold = ordered_bg_predictions[int(round(0.02 * len(ordered_bg_predictions)))]

    # NN outputs of signal events, lowest first
    ordered_sig_predictions = np.sort(y_pred[y_true > 0.5])
    # True positive rate: fraction of signal events strictly above the threshold
    n_at_or_below = np.searchsorted(ordered_sig_predictions, threshold, side='right')
    sig_eff = 1 - n_at_or_below / len(ordered_sig_predictions)

    return sig_eff

Of course, this does not work, because in a custom metric y_true and y_pred are TensorFlow tensors rather than numpy arrays. Is there any way I can make this work correctly?

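There is a sensitivity-at-specificity metric, which I think is equivalent (sensitivity is the true positive rate, and specificity is 1 minus the FPR).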
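In TensorFlow 2.x this metric is available as `tf.keras.metrics.SensitivityAtSpecificity` (the TF 1.x counterpart is `tf.metrics.sensitivity_at_specificity`). The sketch below assumes TF 2.x and sigmoid outputs in [0, 1]; an FPR of X = 0.02 maps to `specificity=0.98`. The toy arrays are made up purely to show the call pattern.

```python
import numpy as np
import tensorflow as tf

# FPR = 1 - specificity, so an FPR of 0.02 corresponds to specificity = 0.98
metric = tf.keras.metrics.SensitivityAtSpecificity(specificity=0.98,
                                                   num_thresholds=1000)

# Toy data: background (0) outputs near 0, signal (1) outputs near 1
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=np.float32)
y_pred = np.array([0.1, 0.1, 0.1, 0.1, 0.9, 0.9, 0.9, 0.9], dtype=np.float32)

# Accumulate confusion counts, then read off the TPR at the chosen specificity
metric.update_state(y_true, y_pred)
print(float(metric.result()))  # 1.0 (the toy classes are perfectly separated)
```

The metric evaluates `num_thresholds` evenly spaced cuts, so the threshold is only as fine as that grid. Because it is stateful, passing it in `model.compile(..., metrics=[...])` accumulates counts over all batches in an epoch rather than averaging per-batch values.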

You can implement your own metric. Here is an example for the false positive rate:

# Written for TensorFlow 1.x graph mode; it imports private helpers from
# tensorflow.python.* that can change or disappear between versions.
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.metrics_impl import _aggregate_across_towers
from tensorflow.python.ops.metrics_impl import true_negatives
from tensorflow.python.ops.metrics_impl import false_positives
from tensorflow.python.ops.metrics_impl import _remove_squeezable_dimensions

def false_positive_rate(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Streaming FPR = FP / (FP + TN); `predictions` must already be 0/1 decisions."""
  if context.executing_eagerly():
    raise RuntimeError('false_positive_rate is not supported when eager '
                       'execution is enabled.')

  with variable_scope.variable_scope(name, 'false_alarm',
                                     (predictions, labels, weights)):
    # Cast to bool: any non-zero value counts as the positive (signal) class
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

    # Streaming counters (local variables) and the ops that update them
    false_p, false_positives_update_op = false_positives(
        labels, predictions, weights,
        metrics_collections=None, updates_collections=None, name=None)
    true_n, true_negatives_update_op = true_negatives(
        labels, predictions, weights,
        metrics_collections=None, updates_collections=None, name=None)

    def compute_false_positive_rate(true_n, false_p, name):
      # FP / (TN + FP), or 0 if no negatives have been seen yet
      return array_ops.where(
          math_ops.greater(true_n + false_p, 0),
          math_ops.div(false_p, true_n + false_p), 0, name)

    def once_across_towers(_, true_n, false_p):
      return compute_false_positive_rate(true_n, false_p, 'value')

    fpr = _aggregate_across_towers(
        metrics_collections, once_across_towers, true_n, false_p)

    update_op = compute_false_positive_rate(true_negatives_update_op,
                                            false_positives_update_op,
                                            'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return fpr, update_op

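You can adapt the code to the true positive rate: use true_positives and false_negatives from metrics_impl instead of true_negatives and false_positives, and return TP / (TP + FN).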
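An adapted metric of that kind measures the TPR at a fixed cut (whatever the boolean cast implies), not at the cut that yields a chosen FPR. As an alternative, here is a minimal sketch that ports the question's NumPy logic to TensorFlow ops (`tf.sort`, `tf.boolean_mask`, `tf.gather`) so it runs on tensors. It assumes TensorFlow 2.x and a batch containing at least one background event; `tpr_at_fpr` and its `fpr` parameter are hypothetical names, and the toy data are made up.

```python
import numpy as np
import tensorflow as tf


def tpr_at_fpr(y_true, y_pred, fpr=0.02):
    """Hypothetical per-batch metric: TPR at the cut giving `fpr` on background.

    Assumes the batch contains at least one background event (label 0).
    """
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])

    # Background outputs, highest first
    bg_pred = tf.sort(tf.boolean_mask(y_pred, y_true < 0.5), direction='DESCENDING')
    n_bg = tf.shape(bg_pred)[0]
    # Position of the cut: round(fpr * n_bg), clipped to the last element
    idx = tf.cast(tf.round(fpr * tf.cast(n_bg, tf.float32)), tf.int32)
    threshold = tf.gather(bg_pred, tf.minimum(idx, n_bg - 1))

    # Fraction of signal outputs strictly above the cut (0 if no signal)
    sig_pred = tf.boolean_mask(y_pred, y_true > 0.5)
    n_pass = tf.reduce_sum(tf.cast(sig_pred > threshold, tf.float32))
    n_sig = tf.cast(tf.size(sig_pred), tf.float32)
    return n_pass / tf.maximum(n_sig, 1.0)


# Usage: 100 background outputs 0.00..0.99 and 10 signal outputs
y_true = np.array([0] * 100 + [1] * 10, dtype=np.float32)
y_pred = np.concatenate([np.arange(100) / 100.0,
                         [0.5, 0.975, 0.98, 0.99, 1.0, 0.1, 0.2, 0.3, 0.4, 0.96]])
print(round(float(tpr_at_fpr(y_true, y_pred)), 4))  # 0.4
```

It can be passed as `metrics=[tpr_at_fpr]` to `model.compile`. Keras then reports the mean of the per-batch values over the epoch, which is noisier than placing the cut on the full validation set, so larger batches give a more stable estimate.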
