
Keras loss function understanding

In order to understand some callbacks of Keras better, I want to artificially create a nan loss.

This is the function:

def soft_dice_loss(y_true, y_pred):
  from keras import backend as K

  # when the draw from N(2, 2) lands in [2, 3), floor division by 1 gives 2.0
  if K.eval(K.random_normal((1, 1), mean=2, stddev=2))[0][0] // 1 == 2.0:
    # return nan: exp(-1e10) underflows to 0, so each term is inf, and inf - inf is nan
    return K.exp(1.0) / K.exp(-10000000000.0) - K.exp(1.0) / K.exp(-10000000000.0)

  epsilon = 1e-6

  # sum over the spatial axes, keeping the batch and channel dimensions
  axes = tuple(range(1, len(y_pred.shape) - 1))
  numerator = 2. * K.sum(y_pred * y_true, axes)
  denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)

  return 1 - K.mean(numerator / (denominator + epsilon))

So normally it calculates the dice loss, but from time to time it should randomly return a nan. However, this does not seem to happen:

[screenshot: Keras training output]

From time to time though, when I try to run the code, it stops right at the start (before the first epoch) with the error: "An operation has None for gradient. Please make sure that all of your ops have a gradient defined".

Does that mean that the random function of Keras is evaluated just once and then always returns the same value? If so, why is that, and how can I create a loss function that returns nan from time to time?

Your first conditional statement is evaluated only once, when the loss function is defined (i.e. called while the computation graph is built; that is why Keras stops right at the start: if the random draw happens to take the nan branch at that moment, the returned loss is a constant that does not depend on y_pred, so it has no gradient). Instead, you could use keras.backend.switch to integrate your conditional into the graph's logic. Your loss function could be something along the lines of:

import keras.backend as K
import numpy as np


def soft_dice_loss(y_true, y_pred):
    epsilon = 1e-6
    # sum over the spatial axes, keeping the batch and channel dimensions
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * K.sum(y_pred * y_true, axes)
    denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
    loss = 1 - K.mean(numerator / (denominator + epsilon))

    # K.switch builds the conditional into the graph, so the random draw is
    # re-sampled on every evaluation; a standard normal exceeds 3 in roughly
    # 0.1% of draws, so the loss becomes nan only occasionally
    return K.switch(condition=K.random_normal((), mean=0, stddev=1) > 3,
                    then_expression=K.variable(np.nan),
                    else_expression=loss)
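
To actually watch a callback react to the nan, you can wire this loss into a model and attach keras.callbacks.TerminateOnNaN, which halts training as soon as the loss becomes nan. The sketch below is illustrative only: the toy Conv2D model and the random data are assumptions, not part of the original question.

import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D
from keras.callbacks import TerminateOnNaN

# A minimal sketch (hypothetical model and data): a single-layer model whose
# output has shape (batch, height, width, channels), so the loss sums over
# the spatial axes (1, 2) as intended.
model = Sequential([
    Conv2D(1, (3, 3), padding='same', activation='sigmoid',
           input_shape=(16, 16, 1))
])
model.compile(optimizer='adam', loss=soft_dice_loss)

# Dummy inputs and binary target masks.
x = np.random.rand(64, 16, 16, 1).astype('float32')
y = (np.random.rand(64, 16, 16, 1) > 0.5).astype('float32')

# TerminateOnNaN stops training the first time the batch loss is nan.
model.fit(x, y, epochs=20, batch_size=8, callbacks=[TerminateOnNaN()])

Note that a standard normal exceeds 3 only about once in 740 draws, so with the threshold above the nan is rare; for a quick test of the callback, a lower threshold such as 1 (exceeded in roughly 16% of draws) makes it appear much sooner.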
