
How to create a custom metric for a TensorFlow optimizer?

I want to minimize/maximize metrics such as F1-score, precision, recall, and my own custom metric. Here is my metric and optimizer code:

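import tensorflow as tf


def my_metric(logits, labels):
    # argmax turns logits / one-hot labels into class indices (class 0 = "positive")
    predicted = tf.argmax(logits, 1)
    actual = tf.argmax(labels, 1)

    # count_nonzero returns int64 by default; cast so the divisions below use float32
    NS = tf.cast(tf.count_nonzero(actual), tf.float32)             # actual negatives (class 1)
    NR = tf.reduce_sum(tf.cast(tf.equal(actual, 0), tf.float32))    # actual positives (class 0)
    TP = tf.reduce_sum(tf.cast(tf.equal(actual+predicted, 0), tf.float32))
    FP = tf.reduce_sum(tf.cast(tf.equal(actual*(1-predicted), 1), tf.float32))
    TN = tf.reduce_sum(tf.cast(tf.equal(actual+predicted, 2), tf.float32))
    FN = tf.reduce_sum(tf.cast(tf.equal(actual+(1-predicted), 0), tf.float32))
    '''
    Precision = TP / (TP + FP)
    Recall = TP / (TP + FN)
    b = 0.5
    denom = (1.0 + b**2) * TP + FN*b**2 + FP
    Fb = (1.0 + b**2) * TP / denom
    '''
    # TP / NR is the true-positive rate, FP / NS the false-positive rate
    Metric = (TP / NR) - (FP / NS)

    return Metric


def training(metric, learning_rate):
    optimizer = tf.train.AdamOptimizer(learning_rate)
    # minimize() needs a gradient path from `metric` back to the trainable variables
    train_op = optimizer.minimize(metric)
    return train_op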
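To sanity-check the counting conventions (class 0 treated as positive, as in my_metric) and the parenthesised formulas, the same quantities can be computed on plain Python lists. This is a minimal stdlib-only sketch; `confusion_counts`, `precision_recall_fbeta` and `tpr_minus_fpr` are hypothetical helper names, not TensorFlow functions. Under this convention NR = TP + FN and NS = FP + TN, so the custom metric TP/NR - FP/NS is the true-positive rate minus the false-positive rate.

```python
def confusion_counts(actual, predicted):
    """Count TP, FP, TN, FN for binary labels, treating class 0 as positive
    (the convention used by my_metric)."""
    pairs = list(zip(actual, predicted))
    tp = sum(1 for a, p in pairs if a == 0 and p == 0)
    fp = sum(1 for a, p in pairs if a == 1 and p == 0)
    tn = sum(1 for a, p in pairs if a == 1 and p == 1)
    fn = sum(1 for a, p in pairs if a == 0 and p == 1)
    return tp, fp, tn, fn


def precision_recall_fbeta(tp, fp, fn, beta=0.5):
    """Precision, recall and F-beta; returns 0.0 where a denominator is zero."""
    precision = tp / (tp + fp) if tp + fp else 0.0  # TP / (TP + FP)
    recall = tp / (tp + fn) if tp + fn else 0.0     # TP / (TP + FN)
    b2 = beta ** 2
    denom = (1.0 + b2) * tp + b2 * fn + fp
    fbeta = (1.0 + b2) * tp / denom if denom else 0.0
    return precision, recall, fbeta


def tpr_minus_fpr(tp, fp, tn, fn):
    """The custom metric TP/NR - FP/NS, using NR = TP + FN and NS = FP + TN."""
    return tp / (tp + fn) - fp / (fp + tn)


counts = confusion_counts([0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 1, 1])
print(counts)  # (2, 1, 2, 1)
```

For these six examples precision, recall and F0.5 all equal 2/3, and the custom metric is 2/3 - 1/3 = 1/3.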

When I try to minimize any of these metrics, I get this error:

ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables [...] and loss Tensor("Training/Sub_3:0", shape=(), dtype=float32).

What should I do to train my neural network using a custom metric instead of a loss function? Do I need to add a gradient definition? How would I do that for the metrics above?

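The metric has to be differentiable with respect to your parameters. The TensorFlow op tf.equal is not differentiable, and neither is tf.argmax: both produce piecewise-constant outputs, so no gradient can flow from the metric back to the network weights. That is what the "No gradients provided for any variable" error is reporting.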
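One common workaround (a substitute technique, not something stated in this answer) is a "soft" surrogate: keep the formula, but replace the hard 0/1 counts built from argmax/equal with predicted probabilities, so each example contributes a value that changes smoothly with the logits. The sketch below is plain Python for a two-class problem with class 0 as positive; `hard_metric`, `soft_metric` and `prob_class0` are hypothetical names. In TensorFlow the analogous change would be to use `tf.nn.softmax(logits)[:, 0]` in place of the argmax-based indicators, and to minimize the negative of the metric if you want to maximize it.

```python
import math


def prob_class0(logit0, logit1):
    """Softmax probability of class 0 for a two-class logit pair (numerically stable)."""
    m = max(logit0, logit1)
    e0 = math.exp(logit0 - m)
    e1 = math.exp(logit1 - m)
    return e0 / (e0 + e1)


def hard_metric(logits, actual):
    """TP/NR - FP/NS with argmax predictions, as in my_metric (ties go to class 0)."""
    predicted = [0 if l0 >= l1 else 1 for l0, l1 in logits]
    nr = sum(1 for a in actual if a == 0)
    ns = len(actual) - nr
    tp = sum(1 for a, p in zip(actual, predicted) if a == 0 and p == 0)
    fp = sum(1 for a, p in zip(actual, predicted) if a == 1 and p == 0)
    return tp / nr - fp / ns


def soft_metric(logits, actual):
    """Same formula, but each example adds its probability of class 0
    instead of a hard 0/1 count, so the value varies smoothly with the logits."""
    p0 = [prob_class0(l0, l1) for l0, l1 in logits]
    nr = sum(1 for a in actual if a == 0)
    ns = len(actual) - nr
    soft_tp = sum(p for a, p in zip(actual, p0) if a == 0)
    soft_fp = sum(p for a, p in zip(actual, p0) if a == 1)
    return soft_tp / nr - soft_fp / ns


logits = [(2.0, 0.0), (0.5, 0.0), (0.0, 1.0), (0.2, 0.0)]
actual = [0, 0, 1, 1]
nudged = [(2.1, 0.0)] + logits[1:]  # small change to one logit

print(hard_metric(logits, actual), hard_metric(nudged, actual))  # 0.5 0.5
print(soft_metric(logits, actual) < soft_metric(nudged, actual))  # True
```

The nudge leaves the hard metric unchanged (zero gradient signal) while the soft version moves. Keep in mind the surrogate only approximates the original metric, so improving it does not guarantee the hard metric improves by the same amount.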

If you are not sure whether an operation is differentiable with respect to your parameters, you can check with tf.gradients: it returns None for any variable that the output does not depend on through differentiable ops.

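import tensorflow as tf

w = tf.Variable(1, name="w", dtype=tf.float32)  # parameter to optimize for
x = tf.placeholder(shape=(), dtype=tf.float32, name="x")  # input
op = tf.multiply(w, x)

# A list with one gradient tensor: op is differentiable w.r.t. w
grads_op_wrt_w = tf.gradients(op, w)
print(grads_op_wrt_w)

# tf.equal has no gradient, so the chain is broken and this prints [None]
non_diff_op = tf.cast(tf.equal(op, 1.0), tf.float32)
print(tf.gradients(non_diff_op, w))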

I have created a small gist with a method that checks the gradient flow of an operation here.
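A framework-free way to illustrate the same check is a central finite-difference estimate of the derivative, applied to a product (like `tf.multiply(w, x)`) and to an equality test (like `tf.cast(tf.equal(...), tf.float32)`). This is a hypothetical stand-in, not the gist's code; `numeric_grad`, `gradient_flows`, `product` and `equality_indicator` are made-up names, and the input value 3.0 is an arbitrary assumption.

```python
def numeric_grad(f, w, eps=1e-4):
    """Central finite-difference estimate of df/dw at the point w."""
    return (f(w + eps) - f(w - eps)) / (2.0 * eps)


def gradient_flows(f, w, tol=1e-8):
    """Hypothetical helper: True if f visibly responds to small changes in w."""
    return abs(numeric_grad(f, w)) > tol


X = 3.0  # fixed input, playing the role of the placeholder x


def product(w):
    # Differentiable, like tf.multiply(w, x)
    return w * X


def equality_indicator(w):
    # Piecewise constant, like tf.cast(tf.equal(w * x, 3.0), tf.float32)
    return 1.0 if w * X == 3.0 else 0.0


print(numeric_grad(product, 1.0))             # close to 3.0
print(numeric_grad(equality_indicator, 1.0))  # 0.0: no gradient signal
```

Unlike tf.gradients, which inspects the graph, a zero finite-difference estimate at a single point is only a hint that no gradient flows, not a proof.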
