confusion_matrix error 'list' object has no attribute 'argmax'

I am building a classification report for a DCNN model, but I am getting an error. My code is:

from sklearn.metrics import confusion_matrix

test = ImageDataGenerator()
test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
test_data = test_generator.flow_from_directory(
    directory="/content/dataset/test",
    target_size=IMAGE_SHAPE,
    color_mode="rgb",
    class_mode="categorical",
    batch_size=1,
    shuffle=False,
)

cm = confusion_matrix(test_labels, predictions.argmax(axis=1))


AttributeError: 'list' object has no attribute 'argmax'

Your predictions object is evidently a plain Python list, and lists have no argmax method. Use the NumPy function np.argmax() instead, which accepts any array-like input, including nested lists:

import numpy as np

predictions = [[0.1, 0.9], [0.8, 0.2]]  # dummy data: a plain Python list

# Calling the method on a list fails:
# predictions.argmax(axis=1)
# AttributeError: 'list' object has no attribute 'argmax'

# Use NumPy, which converts the nested list to an array internally:
y_pred_binary = np.argmax(predictions, axis=1)
# array([1, 0])
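
Applied to the confusion matrix from the question, a minimal sketch might look like the following. The data here is made up for illustration: predictions stands in for whatever the model returned, and test_labels is assumed to be one-hot encoded (which is what class_mode='categorical' produces). The names y_pred and y_true are hypothetical and do not come from the original post.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical stand-in for the model's output, held as a plain Python list
predictions = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]

# Hypothetical one-hot labels (assumed format; yours may differ)
test_labels = [[0, 1], [1, 0], [1, 0], [1, 0]]

# np.argmax works on lists, so no .argmax attribute is needed
y_pred = np.argmax(predictions, axis=1)  # predicted class index per sample
y_true = np.argmax(test_labels, axis=1)  # true class index per sample

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_true, y_pred)
print(cm)
```

If test_labels already holds integer class indices rather than one-hot vectors, pass it to confusion_matrix directly and skip the second np.argmax call. Converting once with np.asarray(predictions) and then calling .argmax(axis=1) on the result is an equivalent alternative.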
