
Custom loss function in Keras, Python?

Let's assume we have the predicted output vector:

y_pred = [1, 0, 0, 1]

and the real output values:

y_true = [0, 1, 0, 0]

I want to build the difference vector y_pred - y_true:

y_diff = [1, -1, 0, 1]

Then I want to count the number of 1s in it and multiply that count by a constant; this should be the result of my custom loss function. The goal is to give more importance to certain kinds of errors (in this case, I want a bigger loss if the predicted value was 0 while the true value was 1).
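The arithmetic described above can be checked with a few lines of NumPy. The weighting constant `C = 10` is an assumed value chosen for illustration. Note that the case the question wants to penalize (true 1, predicted 0) shows up as -1 in `y_pred - y_true`, not as 1.

```python
import numpy as np

# Example vectors from the question
y_pred = np.array([1, 0, 0, 1])
y_true = np.array([0, 1, 0, 0])

# Difference vector as described: y_pred - y_true
y_diff = y_pred - y_true                 # [1, -1, 0, 1]

# Count the 1s and scale by a constant (C is an assumed value)
C = 10
num_ones = int(np.sum(y_diff == 1))      # 2 (positions 0 and 3: predicted 1, true 0)
loss = C * num_ones                      # 20

# "Predicted 0 while true was 1" appears as -1 in y_pred - y_true,
# i.e. as +1 in y_true - y_pred
missed_ones = int(np.sum((y_true - y_pred) == 1))  # 1 (position 1)

print(y_diff.tolist(), num_ones, loss, missed_ones)
```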

This is my implementation attempt:

import tensorflow as tf

def custom_loss_function(y_true, y_pred):
    # if it was 1 and you wrote 0, error is very very big
    y_diff = tf.math.subtract(y_true, y_pred)

    def fn(elem):
        if elem == 1:
            return 10
        elif elem == -1:
            return 1
        else:
            return 0

    return tf.reduce_sum(tf.map_fn(fn, y_diff))

The problem is that, written this way, my loss function is not differentiable. I think this is why I get the error:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
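One way to see where the `None` comes from is to ask `tf.GradientTape` for the gradients directly. The sketch below (the probability-like `y_pred` values are made up for illustration) compares a loss that counts matches via `tf.equal` with one built from the differences themselves. The comparison produces booleans, so there is no gradient path back to `y_pred`, which is what Keras reports as a `None` gradient.

```python
import tensorflow as tf

y_true = tf.constant([0.0, 1.0, 0.0, 0.0])
y_pred = tf.constant([0.9, 0.2, 0.1, 0.8])  # illustrative sigmoid-like outputs

# 1) Loss built by counting matches: tf.equal returns booleans,
#    which carry no gradient back to y_pred.
with tf.GradientTape() as tape:
    tape.watch(y_pred)
    y_diff = y_true - y_pred
    count_loss = 10.0 * tf.reduce_sum(tf.cast(tf.equal(y_diff, 1.0), tf.float32))
grad_count = tape.gradient(count_loss, y_pred)
print(grad_count)  # None

# 2) Loss built from arithmetic on y_diff itself: gradients are defined.
with tf.GradientTape() as tape:
    tape.watch(y_pred)
    y_diff = y_true - y_pred
    abs_loss = tf.reduce_sum(tf.abs(y_diff))
grad_abs = tape.gradient(abs_loss, y_pred)
print(grad_abs.numpy().tolist())  # [1.0, -1.0, 1.0, 1.0]
```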

Any idea how to implement a custom loss function that gives bigger (or smaller) losses depending on some condition, as in this task?

Your question is contradictory: you say you want y_pred - y_true, but your code computes y_true - y_pred. Nevertheless, you can use the following:

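import tensorflow as tf

def custom_loss_function(y_true, y_pred):
    # if it was 1 and you predicted 0, the error is weighted 10x
    y_true = tf.cast(y_true, y_pred.dtype)
    y_diff = y_true - y_pred

    # 10.0 where y_diff == 1, 1.0 everywhere else (the mask itself carries no gradient)
    mul_mask = tf.cast(tf.math.equal(y_diff, 1.0), y_pred.dtype) * 9.0 + 1.0
    # |y_diff * mul_mask|; tf.abs avoids the NaN gradient that sqrt(x**2) gives at x == 0
    y_diff = tf.abs(y_diff * mul_mask)

    return tf.reduce_sum(y_diff)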
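One caveat with this mask: `tf.math.equal(y_diff, 1.0)` only fires when the prediction is exactly 0 where the label is 1. With sigmoid outputs in (0, 1) that practically never happens, so the factor of 10 is rarely applied. A variant, which is an assumption of this edit rather than part of the answer, weights every under-prediction (`y_true - y_pred > 0`) instead; for 0/1 labels and predictions in [0, 1] that is exactly the "true was 1, predicted lower" case. The name `weighted_abs_loss` and the `fn_weight` parameter are hypothetical.

```python
import tensorflow as tf

def weighted_abs_loss(y_true, y_pred, fn_weight=10.0):
    """Absolute error; under-predictions (label above prediction) are weighted by fn_weight."""
    y_true = tf.cast(y_true, y_pred.dtype)
    y_diff = y_true - y_pred
    # Weights are fn_weight where the prediction is below the label, 1 elsewhere.
    # They come from a boolean comparison, so no gradient flows through them.
    weights = 1.0 + (fn_weight - 1.0) * tf.cast(y_diff > 0, y_diff.dtype)
    return tf.reduce_sum(weights * tf.abs(y_diff))

y_true = tf.constant([[0.0, 1.0, 0.0, 0.0]])
y_pred = tf.constant([[0.9, 0.2, 0.1, 0.8]])
loss = weighted_abs_loss(y_true, y_pred)
# 0.9 + 10 * 0.8 + 0.1 + 0.8 = 9.8, up to float32 rounding
print(float(loss))
```

It plugs into Keras like any other custom loss, e.g. `model.compile(optimizer="adam", loss=weighted_abs_loss)`, in which case the default `fn_weight` is used.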

PS: I hope you have a good reason for using this custom loss function, because you can do the weighting simply with the class_weight argument of model.fit(). If you just want per-class weighting, there is no need to implement it yourself.
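For comparison, a minimal sketch of the `class_weight` route, assuming a single-label binary classifier with one sigmoid output; the toy data, layer size and the 10:1 ratio are made up for illustration. Keras turns the dictionary into one weight per sample based on that sample's class index, so it weights whole samples rather than individual elements of a multi-label vector like the one in the question.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 100 samples, 8 features, binary labels (0 or 1)
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 8)).astype("float32")
y = (rng.random(size=(100, 1)) < 0.2).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Samples labelled 1 contribute 10x as much to the loss as samples labelled 0
history = model.fit(x, y, epochs=2, batch_size=16,
                    class_weight={0: 1.0, 1: 10.0}, verbose=0)
print(history.history["loss"])
```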

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 