
TensorFlow 2.0: What is the difference between sparse_categorical_crossentropy and SparseCategoricalCrossentropy?

Reading the docs of TensorFlow 2.0, I found:

tf.keras.losses.sparse_categorical_crossentropy

and

tf.keras.losses.SparseCategoricalCrossentropy

The way they are used in tutorials, their arguments, and their descriptions all look the same to me. What is the difference between the two?

There is none. If you look at the documentation you linked, you can follow it to the source code on GitHub. Both end up calling the same backend function:

# Excerpt (TF 2.0); the rest of the function is omitted
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
  """Categorical crossentropy with integer targets."""
  if not from_logits:
    if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
        output.op.type != 'Softmax'):
      # Clip probabilities away from 0 and 1, then move them to log space.
      epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
      output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
      output = math_ops.log(output)
  # ... (rest of the implementation omitted)

This function is defined in:

tensorflow/python/keras/backend.py
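To make the backend steps concrete, here is a minimal pure-Python sketch of the same computation. It is an illustration, not TensorFlow code: the name `sparse_cce` is hypothetical, and `EPSILON = 1e-7` is assumed to match the Keras default `epsilon()`. For probability inputs it clips and takes the log, as in the excerpt; as far as I can tell, the omitted part of the backend function then passes these log-probabilities to a softmax cross-entropy op, which the sketch imitates with a log-softmax followed by picking the entry at the integer label.

```python
import math
from typing import List, Sequence

EPSILON = 1e-7  # assumption: mirrors the Keras default epsilon()


def _log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def sparse_cce(y_true: Sequence[int], y_pred: Sequence[Sequence[float]],
               from_logits: bool = False) -> List[float]:
    """Per-sample sparse categorical crossentropy (hypothetical sketch)."""
    losses = []
    for label, row in zip(y_true, y_pred):
        if from_logits:
            logits = list(row)
        else:
            # Probabilities: clip into [eps, 1 - eps], then take the log,
            # like the clip_by_value / log steps in the excerpt above.
            logits = [math.log(min(max(p, EPSILON), 1 - EPSILON)) for p in row]
        # The integer label selects one entry of the log-softmax.
        losses.append(-_log_softmax(logits)[label])
    return losses


y_true = [1, 2]
y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]
print([round(v, 4) for v in sparse_cce(y_true, y_pred)])
```

When each probability row already sums to 1, the log-softmax step changes (almost) nothing, so each value is essentially `-log(p[label])`: roughly 0.0513 and 2.3026 for this input.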

For instance, the former (tf.keras.losses.sparse_categorical_crossentropy) is defined like this:

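from tensorflow.python.keras import backend as K

def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
  # Thin wrapper: forwards all arguments to the backend implementation.
  return K.sparse_categorical_crossentropy(
      y_true, y_pred, from_logits=from_logits, axis=axis)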

So it simply delegates to the backend function in tensorflow/python/keras/backend.py.

One is a function and one is a class.

The first one is the functional version: calling it directly returns the loss values, one per sample.

The second one is the class version. You first create an instance of the class and then call that instance to get the loss value.
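A rough pure-Python analogy of the two styles follows. It assumes the Keras class behaves like a thin wrapper that stores its settings in the constructor, calls the function, and by default averages the per-sample values (as far as I know, the default reduction in TF 2.x is `AUTO`, which amounts to a mean over the batch). `sparse_cce_fn`, `SparseCCE` and the `reduction` strings are hypothetical names made up for this sketch, and the loss itself is simplified to a clipped `-log(p[label])`.

```python
import math
from typing import List, Sequence


def sparse_cce_fn(y_true: Sequence[int], y_pred: Sequence[Sequence[float]],
                  eps: float = 1e-7) -> List[float]:
    # Functional style: takes the data and returns one loss per sample.
    return [-math.log(min(max(row[label], eps), 1 - eps))
            for label, row in zip(y_true, y_pred)]


class SparseCCE:
    """Class style: settings are fixed once, in the constructor."""

    def __init__(self, reduction: str = "mean", eps: float = 1e-7):
        self.reduction = reduction
        self.eps = eps

    def __call__(self, y_true, y_pred):
        # Calling the instance computes the loss with the stored settings.
        losses = sparse_cce_fn(y_true, y_pred, self.eps)
        if self.reduction == "mean":
            return sum(losses) / len(losses)
        if self.reduction == "sum":
            return sum(losses)
        return losses  # "none": per-sample values, like the function


y_true = [1, 2]
y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]

per_sample = sparse_cce_fn(y_true, y_pred)  # call the function directly
loss_obj = SparseCCE()                      # build an instance first...
scalar = loss_obj(y_true, y_pred)           # ...then call the instance
print(per_sample, scalar)
```

The design trade-off is that the class lets you fix options such as the reduction once and reuse the configured object, while the function takes everything at call time.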

I believe you are right that there is little difference. If you are using the Keras API, the only visible difference is how you pass the loss when you compile the model.

For example:

model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy)

vs

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy())

Notice the extra parentheses in the class version: you need to pass an instance of the class, not the class itself.
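The sketch below illustrates why the parentheses matter, using hypothetical stand-ins (`loss_fn`, `LossClass`, `TinyModel`) rather than Keras classes. The pretend `compile` only stores a callable and later calls it as `loss(y_true, y_pred)`. A function and an instance both fit that call; the bare class does not, because calling it runs the constructor. In this sketch that raises `TypeError`; what the real Keras class does in the same situation is not shown here.

```python
import math


def loss_fn(y_true, y_pred):
    # Functional form: returns one loss value per sample.
    return [-math.log(row[t]) for t, row in zip(y_true, y_pred)]


class LossClass:
    def __init__(self, scale=1.0):
        # Configuration is captured when the instance is built.
        self.scale = scale

    def __call__(self, y_true, y_pred):
        vals = loss_fn(y_true, y_pred)
        return self.scale * sum(vals) / len(vals)


class TinyModel:
    # Hypothetical stand-in for a model: compile() only stores a callable,
    # evaluate() later calls it as loss(y_true, y_pred).
    def compile(self, loss):
        self.loss = loss

    def evaluate(self, y_true, y_pred):
        return self.loss(y_true, y_pred)


y_true, y_pred = [0], [[0.5, 0.5]]
m = TinyModel()

m.compile(loss=loss_fn)        # pass the function itself
print(m.evaluate(y_true, y_pred))

m.compile(loss=LossClass())    # pass an instance: note the ()
print(m.evaluate(y_true, y_pred))

m.compile(loss=LossClass)      # passing the class by mistake...
try:
    m.evaluate(y_true, y_pred)  # ...now calls the constructor instead
except TypeError as err:
    print("TypeError:", err)
```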
