
TensorFlow custom training step with different loss functions

Background

According to the TensorFlow documentation, a custom training step can be performed as follows:

# Fake sample data for testing; `model` and `optimizer` are assumed
# to be defined elsewhere
x_batch_train = tf.zeros([32, 3, 1], dtype="float32")
y_batch_train = tf.zeros([32], dtype="float32")
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
with tf.GradientTape() as tape:
    logits = model(x_batch_train, training=True)
    loss_value = loss_fn(y_batch_train, logits)

grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))

But if I want to use a different loss function, such as categorical cross-entropy, I would need to argmax the logits inside the gradient tape:

loss_fn = tf.keras.losses.get("categorical_crossentropy")
with tf.GradientTape() as tape:
    logits = model(x_batch_train, training=True)
    prediction = tf.cast(tf.argmax(logits, axis=-1), y_batch_train.dtype)
    loss_value = loss_fn(y_batch_train, prediction)

grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))

Problem

The problem with this is that tf.argmax is not differentiable, so TensorFlow cannot compute gradients through it, and you get the error:

ValueError: No gradients provided for any variable: [...]
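The non-differentiability can be reproduced in isolation. This is a minimal sketch (not from the original post): applying tf.argmax inside a tape makes tape.gradient return None, which is what triggers the ValueError once apply_gradients receives those None values.

```python
import tensorflow as tf

# A trainable variable standing in for model outputs
x = tf.Variable([[0.2, 0.8]])

with tf.GradientTape() as tape:
    # argmax produces an integer index; no gradient flows through it
    pred = tf.cast(tf.argmax(x, axis=-1), tf.float32)
    loss = tf.reduce_sum(pred)

grad = tape.gradient(loss, x)
print(grad)  # None: the graph from x to loss is not differentiable
```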

My question: Without changing the loss function how could I make the second example work?

categorical_crossentropy expects your labels to be one-hot encoded, so make sure of that first. Then pass the output of your model directly to the loss; that output should be one probability (or logit) per category. More info: https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy#standalone_usage
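Putting that together, the training step from the question can be made to work by one-hot encoding the integer labels with tf.one_hot and feeding the raw logits to the loss, with no argmax in the tape. This is a self-contained sketch; the small Sequential model, num_classes, and the Adam optimizer are illustrative assumptions, not part of the original post.

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical model matching the [32, 3, 1] fake inputs from the question
num_classes = 10
model = keras.Sequential([
    keras.Input(shape=(3, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(num_classes),  # raw logits, no softmax
])
optimizer = keras.optimizers.Adam()

x_batch_train = tf.zeros([32, 3, 1], dtype="float32")
y_batch_train = tf.zeros([32], dtype="int32")  # integer class labels

# One-hot encode the labels, as categorical_crossentropy expects
y_one_hot = tf.one_hot(y_batch_train, depth=num_classes)

loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
with tf.GradientTape() as tape:
    logits = model(x_batch_train, training=True)
    # Pass the logits directly -- no argmax, so gradients can flow
    loss_value = loss_fn(y_one_hot, logits)

grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
```

With from_logits=True the loss applies the softmax internally, so the model can keep emitting raw scores and the whole computation stays differentiable.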
