
Sparse Categorical CrossEntropy causing NAN loss

I've been trying to implement a few custom losses, and thought I'd start with sparse categorical cross-entropy (SCE) loss, without using the built-in TF object. Here's the function I wrote for it.

def custom_loss(y_true, y_pred):
    print(y_true, y_pred)
    return tf.cast(tf.math.multiply(tf.experimental.numpy.log2(y_pred[y_true[0]]), -1), dtype=tf.float32)

y_pred is the set of probabilities, and y_true is the index of the correct one. This setup should work according to everything I've read, but it returns a NaN loss.

I checked whether there's a problem with the training loop, but it works perfectly with the built-in losses.

Could someone tell me what the problem is with this code?

You can replicate the SparseCategoricalCrossentropy() loss function as follows:

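import tensorflow as tf

def sparse_categorical_crossentropy(y_true, y_pred, clip=True):
    # y_true: integer class indices, shape (batch,)
    # y_pred: predicted probabilities, shape (batch, classes)
    y_true = tf.convert_to_tensor(y_true, dtype=tf.int32)
    y_pred = tf.convert_to_tensor(y_pred, dtype=tf.float32)

    # One-hot encode the labels so they can act as a mask over y_pred
    y_true = tf.one_hot(y_true, depth=y_pred.shape[1])

    if clip:
        # Keep probabilities away from 0 and 1 so that the log stays finite
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)

    # Mean negative log-probability of the true class
    return -tf.reduce_mean(tf.math.log(y_pred[y_true == 1]))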

Note that the SparseCategoricalCrossentropy() loss function clips the predicted probabilities to the range [1e-7, 1 - 1e-7] to make sure that the loss values are always finite; see also this question.
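One thing the function in the question does differently is index with y_pred[y_true[0]], which selects an entire row of the batch rather than each sample's true-class probability, and it does no clipping, so a zero probability goes straight into the log. That may be part of the NaN problem. Below is a minimal sketch of per-sample indexing with tf.gather, assuming y_pred has shape (batch, classes) and y_true holds one integer label per sample. The name gather_sce and the eps parameter are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

def gather_sce(y_true, y_pred, eps=1e-7):
    # Hypothetical helper; assumes y_pred has shape (batch, classes)
    y_pred = tf.convert_to_tensor(y_pred, dtype=tf.float32)
    # Flatten the labels to shape (batch,) and make sure they are integers
    y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
    # With batch_dims=1, row b of y_pred is indexed by y_true[b]
    p_true = tf.gather(y_pred, y_true, batch_dims=1)
    # Clip so that log(0) can never occur
    p_true = tf.clip_by_value(p_true, eps, 1.0 - eps)
    return -tf.reduce_mean(tf.math.log(p_true))

loss = gather_sce([1, 2], [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])
print(loss.numpy())
```

This avoids building the one-hot mask, but it should give the same value as the function above for the same inputs.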

y_true = [1, 2]
y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]

print(tf.keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=True).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=False).numpy())
# 1.1769392
# 1.1769392
# 1.1769392

y_true = [1, 2]
y_pred = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]

print(tf.keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=True).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=False).numpy())
# 8.059048
# 8.059048
# inf
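To see where these numbers come from, the arithmetic can be reproduced by hand. The sketch below uses only the standard library and assumes the 1e-7 clipping bound mentioned above; clipped_nll is a hypothetical helper name. In the second example the true-class probabilities are 1.0 and 0.0, so without clipping the second term is -log(0), which is infinite.

```python
import math

EPS = 1e-7  # clipping bound quoted in the answer (assumed here)

def clipped_nll(p, eps=EPS):
    # Negative log of one true-class probability after clipping to [eps, 1 - eps]
    p = min(max(p, eps), 1.0 - eps)
    return -math.log(p)

# First example: true-class probabilities are 0.95 and 0.1
first_mean = (clipped_nll(0.95) + clipped_nll(0.1)) / 2

# Second example: true-class probabilities are 1.0 and 0.0
second_mean = (clipped_nll(1.0) + clipped_nll(0.0)) / 2

print(first_mean, second_mean)
```

The second mean is dominated by -log(1e-7), roughly 16.118, which is halved by averaging over the two samples.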
