
Confusion Matrix produces different results in Keras model tf==2.3.0

With a Keras Sequential model, to get class labels from the predictions we can do:

yhat_classes1 = Keras_model.predict_classes(predictors)[:, 0]  # this shows a deprecation warning in tf==2.3.0

WARNING:tensorflow:From <ipython-input-54-226ad21ffae4>:1: Sequential.predict_classes (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.
Instructions for updating:
Please use instead:
* `np.argmax(model.predict(x), axis=-1)`, if your model does multi-class classification (e.g. if it uses a `softmax` last-layer activation).
* `(model.predict(x) > 0.5).astype("int32")`, if your model does binary classification (e.g. if it uses a `sigmoid` last-layer activation).
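For reference, `Sequential.predict_classes` in TF 2.x picked between the two recipes in this warning based on the width of the output: `argmax` when the last layer has more than one unit, a 0.5 threshold otherwise. The sketch below reproduces that dispatch with NumPy only; `probs_to_classes` is a hypothetical helper name, and the arrays are made-up stand-ins for `Keras_model.predict(predictors)`.

```python
import numpy as np


def probs_to_classes(proba, threshold=0.5):
    """Hypothetical helper mirroring the dispatch predict_classes performed:
    argmax for multi-unit outputs, thresholding for a single sigmoid unit."""
    proba = np.asarray(proba)
    if proba.shape[-1] > 1:
        # Multi-class (e.g. softmax): index of the largest probability per row
        return np.argmax(proba, axis=-1)
    # Binary (one sigmoid unit): threshold, keeping the (n, 1) shape
    return (proba > threshold).astype("int32")


# Made-up sigmoid output, shape (n, 1)
sigmoid_out = np.array([[0.1], [0.7], [0.4], [0.9]])
print(probs_to_classes(sigmoid_out)[:, 0])

# Made-up two-unit softmax output, shape (n, 2)
softmax_out = np.array([[0.2, 0.8], [0.6, 0.4]])
print(probs_to_classes(softmax_out))
```

Because the single-unit branch keeps the (n, 1) shape, the question's `[:, 0]` is still needed to get a 1-D label vector.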

So I switched to:

yhat_classes2 = np.argmax(Keras_model.predict(predictors), axis=-1)

If I create a confusion matrix from the first set of class labels, I get:

matrix = confusion_matrix(actual_y, yhat_classes1)
 [[108579   8674]
 [  1205  24086]]

But when I create the confusion matrix from the second set of class labels, I get 0 for both True Positives and False Positives:

matrix = confusion_matrix(actual_y, yhat_classes2)
 [[117253      0]
 [ 25291      0]]
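The two matrices above can be unpacked with a few lines of NumPy. This sketch assumes scikit-learn's layout (rows are actual classes, columns are predicted classes, labels sorted as [0, 1]), in which `ravel()` returns `tn, fp, fn, tp`; it only reuses the numbers printed in the question.

```python
import numpy as np

# Confusion matrices copied from the question
m1 = np.array([[108579, 8674], [1205, 24086]])  # from yhat_classes1
m2 = np.array([[117253, 0], [25291, 0]])        # from yhat_classes2

for name, m in [("yhat_classes1", m1), ("yhat_classes2", m2)]:
    tn, fp, fn, tp = m.ravel()
    # Row sums: how many samples actually belong to each class
    actual_counts = m.sum(axis=1)
    # Column sums: how often each class was predicted
    predicted_counts = m.sum(axis=0)
    print(name, "tn=%d fp=%d fn=%d tp=%d" % (tn, fp, fn, tp),
          "actual:", actual_counts, "predicted:", predicted_counts)
```

The row totals are identical in both cases, so `actual_y` is the same; only the predictions differ, and in the second matrix nothing was ever predicted as class 1.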

May I know what the issue is here?

The confusion matrix has 2 rows and 2 columns, which leads me to believe that you have two classes. The warning specifically says that you should use this line for binary classification, which is what you're doing:

(model.predict(x) > 0.5).astype("int32")

Please use instead:
* np.argmax(model.predict(x), axis=-1), if your model does multi-class classification (e.g. if it uses a softmax last-layer activation).
* (model.predict(x) > 0.5).astype("int32"), if your model does binary classification (e.g. if it uses a sigmoid last-layer activation).
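A minimal sketch of the binary replacement, assuming the model ends in a single sigmoid unit so that `model.predict(x)` returns one probability per sample with shape (n, 1). The array below is made-up data standing in for that output; `.ravel()` flattens the result to 1-D, matching `predict_classes(x)[:, 0]` from the question.

```python
import numpy as np

# Made-up stand-in for model.predict(x) on a model whose last layer is
# Dense(1, activation="sigmoid"): one probability per sample, shape (n, 1)
proba = np.array([[0.03], [0.62], [0.49], [0.97], [0.51]])

# Binary replacement for predict_classes: probability above 0.5 -> class 1
yhat = (proba > 0.5).astype("int32")

# Flatten (n, 1) -> (n,) so the labels are a plain 1-D vector,
# like predict_classes(x)[:, 0] in the question
yhat_flat = yhat.ravel()
print(yhat_flat)
```

In the question's code this would be `yhat_classes2 = (Keras_model.predict(predictors) > 0.5).astype("int32").ravel()`, which applies the same rule `predict_classes` used and so should reproduce the first confusion matrix.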

The error is that you used np.argmax(model.predict(X), axis=-1) on an output with a single column (shape (n, 1), from one sigmoid unit). Taking the argmax across one column always returns index 0, because the maximum can only be in that column. So every sample is predicted as class 0, which is why all your counts land in the first column of your confusion matrix.
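To see the failure directly, here is a small NumPy sketch; the (n, 1) array is made-up data standing in for `model.predict(x)` on a single-sigmoid model.

```python
import numpy as np

# Made-up stand-in for model.predict(x) with one sigmoid output unit: shape (n, 1)
proba = np.array([[0.1], [0.8], [0.95], [0.3]])

# Each row has exactly one column, so argmax over axis=-1 can only return 0:
# every sample is labelled class 0, regardless of its probability
labels = np.argmax(proba, axis=-1)
print(labels)
print(proba.shape[-1])  # a single column, so argmax has nothing to choose between
```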
