
How to create an “exact match” eval_metric_op for TensorFlow?

I am trying to create an eval_metric_op that reports the proportion of exact matches at a given threshold for a multi-label classification problem. The following function returns 0 (no exact match) or 1 (exact match) for a single example at the given threshold.

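import numpy as np

def exact_match(y_true, y_logits, threshold):
    # Predict 1 for every label whose score reaches the threshold
    # (the same >= rule as the TensorFlow version below).
    y_pred = (y_logits >= threshold).astype(int)
    # 1 only if every label matches, otherwise 0
    return int(np.array_equal(y_true, y_pred))

y_true = np.array([1,1,0])
y_logits = np.array([0.67, 0.9, 0.55])

print(exact_match(y_true, y_logits, 0.5))  # 0
print(exact_match(y_true, y_logits, 0.6))  # 1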

A threshold of 0.5 yields a prediction of [1,1,1], which is incorrect, so the function returns 0. A threshold of 0.6 yields a prediction of [1,1,0], which is correct, so the function returns 1.
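The metric I ultimately want is the proportion of exact matches over many examples. As a NumPy reference, a minimal sketch (assuming inputs of shape (batch_size, n_labels); `exact_match_ratio` and the second example row are hypothetical, made up for illustration) treats a row as an exact match only if every label agrees and then averages over rows:

```python
import numpy as np

def exact_match_ratio(y_true, y_logits, threshold):
    """Hypothetical helper: fraction of rows whose thresholded
    predictions match y_true on every label.

    Assumes y_true and y_logits have shape (batch_size, n_labels).
    """
    y_pred = (np.asarray(y_logits) >= threshold).astype(int)
    # A row counts as an exact match only if all of its labels agree.
    row_match = np.all(y_pred == np.asarray(y_true), axis=1)
    # The mean of the boolean row matches is the exact-match proportion.
    return float(np.mean(row_match))

y_true = np.array([[1, 1, 0],
                   [0, 1, 1]])
y_logits = np.array([[0.67, 0.9, 0.55],
                     [0.20, 0.7, 0.65]])

print(exact_match_ratio(y_true, y_logits, 0.5))  # 0.5
print(exact_match_ratio(y_true, y_logits, 0.6))  # 1.0
```

At 0.5 only the second row matches; at 0.6 both do.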

I would like to turn this function into a TensorFlow eval metric op. Can anybody advise on the best way to do this?

I can reproduce the same logic with the TensorFlow ops below, but I'm not sure how to turn it into a custom eval_metric_op:

import tensorflow as tf  # TensorFlow 1.x API (tf.to_float, tf.Session)

def exact_match_fn(y_true, y_logits, threshold):
    # 1.0 where a score reaches the threshold, else 0.0
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
    pred_match = tf.equal(predictions, tf.round(y_true))
    # The minimum over all labels is 1.0 only if every label matches
    exact_match = tf.reduce_min(tf.to_float(pred_match))
    return exact_match

graph = tf.Graph()
with graph.as_default():
    y_true = tf.constant([1,1,0], dtype=tf.float32)
    y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32)
    exact_match_50 = exact_match_fn(y_true, y_logits, 0.5)
    exact_match_60 = exact_match_fn(y_true, y_logits, 0.6)

with tf.Session(graph=graph) as sess:
    print(sess.run([exact_match_50, exact_match_60]))

The code above gives exact_match_50 = 0.0 (at least one prediction is incorrect) and exact_match_60 = 1.0 (all labels are correct).

Is it sufficient to simply wrap this in tf.contrib.metrics.streaming_mean(), or is there a better alternative? I would implement it as:

tf.contrib.metrics.streaming_mean(exact_match_fn(y_true, y_logits, threshold))
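To see what such a streaming mean produces, here is a minimal sketch using tf.compat.v1.metrics.mean, a close core counterpart of streaming_mean that is reachable from TensorFlow 1.13+ and 2.x. It assumes graph mode (an explicit tf.Graph), a per-example version of exact_match_fn (reducing over axis=1 so the mean counts examples, not batches), and made-up batches for illustration. The metric returns a (value, update_op) pair. Running update_op once per batch accumulates a running total and count in local variables, which must be initialized first.

```python
import numpy as np
import tensorflow as tf

def exact_match_fn(y_true, y_logits, threshold):
    # 1.0 where a score reaches the threshold, else 0.0
    predictions = tf.cast(tf.greater_equal(y_logits, threshold), tf.float32)
    pred_match = tf.equal(predictions, tf.round(y_true))
    # Reduce over labels only: one 0/1 value per example
    return tf.reduce_min(tf.cast(pred_match, tf.float32), axis=1)

graph = tf.Graph()
with graph.as_default():
    y_true = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
    y_logits = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
    # update_op adds a batch to the running total/count; mean_value reads it
    mean_value, update_op = tf.compat.v1.metrics.mean(
        exact_match_fn(y_true, y_logits, 0.6))
    init_locals = tf.compat.v1.local_variables_initializer()

# Hypothetical evaluation data split into two batches of different sizes
batches = [
    (np.array([[1, 1, 0]], np.float32),
     np.array([[0.67, 0.9, 0.55]], np.float32)),
    (np.array([[0, 1, 1], [1, 0, 0]], np.float32),
     np.array([[0.2, 0.7, 0.65], [0.3, 0.1, 0.9]], np.float32)),
]

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_locals)  # metric counters live in local variables
    for t, l in batches:
        sess.run(update_op, feed_dict={y_true: t, y_logits: l})
    result = sess.run(mean_value)

print(result)  # 2 of the 3 examples are exact matches
```

In an Estimator model_fn, the same (mean_value, update_op) pair is what goes into the eval_metric_ops dictionary, e.g. under a key such as "exact_match".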

The output of your exact_match_fn is already an op that can be used for evaluation. If you want the proportion of exact matches over a batch, make reduce_min reduce only over the label axis, so that each example gets its own 0/1 score, and then average those scores.

For example, if y_true and y_logits each have shape (batch_size, n):

def exact_match_fn(y_true, y_logits, threshold):
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
    pred_match = tf.equal(predictions, tf.round(y_true))
    # Reduce over the label axis only: one 0/1 value per example
    exact_match = tf.reduce_min(tf.to_float(pred_match), axis=1)
    return exact_match


def exact_match_prop_fn(*args):
    # Proportion of exact matches in the batch
    return tf.reduce_mean(exact_match_fn(*args))

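This gives you the proportion over a single batch. If you want it over an entire dataset, you could collect the matches (or the correct and total counts) and compute the ratio outside of the session, but streaming_mean should do exactly that: it keeps a running total and count in local variables and returns a (mean, update_op) pair, where running update_op on each batch accumulates the counts.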
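The "collect the correct and total counts" option can be sketched in plain Python and NumPy. ExactMatchCounter is a hypothetical helper, not part of any library, and the batches are made up. It keeps counts rather than per-batch means so that batches of different sizes are weighted by their number of examples:

```python
import numpy as np

class ExactMatchCounter:
    """Hypothetical helper: accumulates exact-match counts across batches."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.correct = 0  # rows whose labels all match
        self.total = 0    # rows seen so far

    def update(self, y_true, y_logits):
        y_pred = (np.asarray(y_logits) >= self.threshold).astype(int)
        row_match = np.all(y_pred == np.asarray(y_true), axis=1)
        self.correct += int(row_match.sum())
        self.total += row_match.size
        return self.result()

    def result(self):
        return self.correct / self.total if self.total else 0.0

counter = ExactMatchCounter(threshold=0.6)
counter.update(np.array([[1, 1, 0]]), np.array([[0.67, 0.9, 0.55]]))
counter.update(np.array([[0, 1, 1], [1, 0, 0]]),
               np.array([[0.2, 0.7, 0.65], [0.3, 0.1, 0.9]]))
print(counter.correct, counter.total)  # 2 3
```

Averaging the two per-batch proportions (1.0 and 0.5) would give 0.75, whereas the count-based ratio is 2/3. This is why a streaming metric tracks a total and a count instead of a mean of means.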
