
Why does a Keras model predict all ones when used with one-hot labels, categorical_crossentropy, and a softmax output?

I have a simple tf.keras model:

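from tensorflow import keras
from tensorflow.keras import layers

# `init` is a kernel initializer defined elsewhere (not shown in this snippet)
inputs = keras.Input(shape=(9824,))
dense = layers.Dense(512, activation=keras.activations.relu, kernel_initializer=init)
x = dense(inputs)
x = layers.Dense(512, activation=keras.activations.relu)(x)
# Three-class softmax output
outputs = layers.Dense(3, activation=keras.activations.softmax)(x)
model = keras.Model(inputs=inputs, outputs=outputs)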

When I compile it with sparse categorical crossentropy and integer class labels, it works as expected. But when I one-hot encode the labels (with tf.keras.utils.to_categorical) and use categorical_crossentropy (so that I can use recall and precision as metrics during training), the model predicts all ones:

>>> print(predictions)
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 ...
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
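
For reference, the two loss setups described above should give the same loss value for the same prediction, since one-hot encoding only changes how the target is represented. Here is a minimal pure-Python sketch of that equivalence. The helper names `one_hot`, `categorical_crossentropy`, and `sparse_categorical_crossentropy` are illustrative stand-ins written for this example, not the Keras implementations, and the probabilities are made up:

```python
import math

def one_hot(label, num_classes):
    """One-hot vector for an integer class label (illustrative helper)."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def categorical_crossentropy(target, probs):
    """Cross-entropy between a one-hot target and a probability vector."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

def sparse_categorical_crossentropy(label, probs):
    """Cross-entropy between an integer label and a probability vector."""
    return -math.log(probs[label])

probs = [0.7, 0.2, 0.1]  # hypothetical softmax output for one sample
label = 1                # integer class label for that sample

loss_sparse = sparse_categorical_crossentropy(label, probs)
loss_onehot = categorical_crossentropy(one_hot(label, 3), probs)
print(loss_sparse, loss_onehot)  # the two values agree
```

If the losses agree, switching the label encoding alone should not change what the network learns, which points toward something outside the model itself.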

If I understand correctly, the softmax activation in the output layer should make every output lie in the range (0, 1) and every row sum to 1. So how is it possible that the class predictions are all 1? I searched for an answer for hours, to no avail.
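
That reading of softmax can be checked with a small sketch. The `softmax` function below is a plain-Python stand-in (not the Keras activation) and the logits are made up. It shows that a row of all ones cannot come straight out of a softmax layer, which suggests the values were changed somewhere after the output layer:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)                      # subtract the max to avoid overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

row = softmax([2.0, -1.0, 0.5])          # hypothetical logits for one sample
print(row, sum(row))                     # entries in (0, 1), summing to 1

# A row such as [1.0, 1.0, 1.0] sums to 3, so it fails the same check.
suspicious = [1.0, 1.0, 1.0]
looks_like_softmax = abs(sum(suspicious) - 1.0) < 1e-6
print(looks_like_softmax)
```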

EDIT

Here is a minimal example.

I forgot to mention that I use the scikeras package. Based on the examples in the documentation, I assume the model is compiled implicitly. Here is the classifier constructor:

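from scikeras.wrappers import KerasClassifier
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.optimizers import Adam

# keras_model_target is the function that builds the model shown above
clf = KerasClassifier(
    model=keras_model_target,
    loss=SparseCategoricalCrossentropy(),
    name="model_target",
    optimizer=Adam(),
    init=GlorotUniform(),
    metrics=[SparseCategoricalAccuracy()],
    epochs=5,
    batch_size=128
)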

I fit the model with

result = clf.fit(x_train, y_train)

and predict with:

predictions = clf.predict(x)

This was a bug in SciKeras that was fixed with the v0.3.1 release. Updating to the latest version should fix the issue.

As for the bug itself, it was caused by how we were indexing NumPy arrays; see this diff for details.
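
To check whether an environment already has the fix, one option is to compare the installed SciKeras version against v0.3.1. The sketch below uses only the standard library. The helpers `parse_version` and `has_fix` are hypothetical names written for this example, and the parsing is deliberately simplified: it ignores pre-release suffixes such as "rc1".

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (0, 3, 1)  # release named in the answer above

def parse_version(text):
    """Turn '0.3.1' into (0, 3, 1), keeping only the leading digits of each part."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)

def has_fix(installed):
    """True if the given version string is at or after the fixed release."""
    return parse_version(installed) >= FIXED_IN

try:
    installed = version("scikeras")
    status = "includes the fix" if has_fix(installed) else "predates the fix"
    print(installed, status)
except PackageNotFoundError:
    print("scikeras is not installed")
```

If the installed version predates the fix, `pip install --upgrade scikeras` should bring in a fixed release.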
