
How can I make a colored confusion matrix (heatmap), or show the full matrix?

I'm plotting a confusion matrix of predicted vs. actual labels for 26 classes (the 26 letters of the alphabet).

My code is as follows:

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

# predictions: model output of shape (n_samples, 26), one score per class
y_pred = np.argmax(predictions, axis=1)  # transform predictions into a 1-D array of label numbers

# rows = actual labels ("a" prefix), columns = predicted labels ("p" prefix)
pd.DataFrame(confusion_matrix(y_test, y_pred),
             columns=["pA", "pB", "pC", "pD", "pE", "pF", "pG", "pH", "pI", "pJ", "pK", "pL", "pM", "pN", "pO", "pP", "pQ", "pR", "pS", "pT", "pU", "pV", "pW", "pX", "pY", "pZ"],
             index=["aA", "aB", "aC", "aD", "aE", "aF", "aG", "aH", "aI", "aJ", "aK", "aL", "aM", "aN", "aO", "aP", "aQ", "aR", "aS", "aT", "aU", "aV", "aW", "aX", "aY", "aZ"])
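For reference, scikit-learn's confusion_matrix puts the actual labels on the rows and the predicted labels on the columns, which is why the index above uses the "a" prefix and the columns use "p". The sketch below reproduces that counting with NumPy alone, assuming integer labels 0 to n_classes - 1; confusion_counts is a hypothetical helper name, not part of any library.

```python
import numpy as np


def confusion_counts(y_true, y_pred, n_classes):
    """Count (actual, predicted) pairs: rows = actual class, columns = predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates repeated (row, col) pairs, unlike cm[y_true, y_pred] += 1
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


# Tiny 3-class example: 0 = A, 1 = B, 2 = C
y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 1, 1, 2, 0, 2]
cm = confusion_counts(y_true, y_pred, 3)
print(cm)
# [[1 1 0]
#  [0 1 0]
#  [1 0 2]]
```

The diagonal holds the correct predictions; an off-diagonal cell such as cm[2, 0] counts samples that were actually C but predicted as A.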

My output is cut off: part of the matrix is hidden behind "...".

So my question is: how can I either display all the rows and columns (without the "..."), or, to make it more visually pleasing and easier to read, change it to a colored version? I'm OK even if it doesn't show the numbers.

Thanks for taking the time to go over and help me with this, cheers!
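For the first option, pandas decides how much of a frame to print from its display options (display.max_rows, display.max_columns, display.width). A minimal sketch, using an identity matrix as stand-in data rather than the real confusion matrix, that lifts those limits only inside a with block via pd.option_context:

```python
import numpy as np
import pandas as pd

letters = [chr(ord("A") + k) for k in range(26)]  # "A".."Z"
# Stand-in data (identity matrix) in place of confusion_matrix(y_test, y_pred)
df = pd.DataFrame(np.eye(26, dtype=int),
                  columns=["p" + c for c in letters],
                  index=["a" + c for c in letters])

# Temporarily remove the row/column limits (None = no limit) and widen the
# output line so all 26 columns fit before pandas wraps the frame
with pd.option_context("display.max_rows", None,
                       "display.max_columns", None,
                       "display.width", 250):
    full_repr = repr(df)
    print(df)

# to_string() also renders every row and column, regardless of display options
full_text = df.to_string()
```

Calling pd.set_option with the same option names instead makes the change last for the rest of the session.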

You can treat the DataFrame as an image and draw it with plt.imshow:

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import confusion_matrix

df = pd.DataFrame(confusion_matrix(y_test, y_pred),
         columns=["pA", "pB", "pC", "pD", "pE", "pF", "pG", "pH", "pI", "pJ", "pK", "pL", "pM", "pN", "pO", "pP", "pQ", "pR", "pS", "pT", "pU", "pV", "pW", "pX", "pY", "pZ"],
         index=["aA", "aB", "aC", "aD", "aE", "aF", "aG", "aH", "aI", "aJ", "aK", "aL", "aM", "aN", "aO", "aP", "aQ", "aR", "aS", "aT", "aU", "aV", "aW", "aX", "aY", "aZ"])

plt.imshow(df)  # each cell is drawn as a pixel whose color encodes its count
plt.show()
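By default plt.imshow labels both axes with the positions 0 to 25. A sketch, with random counts standing in for the real matrix, that puts the letter labels on the ticks, names the axes, and adds a colorbar so the colors can be read as counts; the "Blues" colormap is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

letters = [chr(ord("A") + k) for k in range(26)]
# Random counts standing in for confusion_matrix(y_test, y_pred)
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.integers(0, 10, size=(26, 26)),
                  columns=["p" + c for c in letters],
                  index=["a" + c for c in letters])

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(df.to_numpy(), cmap="Blues")  # one colored cell per count

# Replace the default 0..25 tick labels with the row/column names
ax.set_xticks(range(len(df.columns)))
ax.set_xticklabels(df.columns, rotation=90)
ax.set_yticks(range(len(df.index)))
ax.set_yticklabels(df.index)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")

cbar = fig.colorbar(im, ax=ax)  # scale mapping colors back to counts
cbar.set_label("Count")
fig.tight_layout()
plt.show()
```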

Edit:

You can also write each cell's value on the image using plt.annotate (random data stands in for the real matrix here):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

cols = ["pA", "pB", "pC", "pD", "pE", "pF", "pG", "pH", "pI", "pJ", "pK", "pL", "pM", "pN", "pO", "pP", "pQ", "pR", "pS", "pT", "pU", "pV", "pW", "pX", "pY", "pZ"]
ix = ["aA", "aB", "aC", "aD", "aE", "aF", "aG", "aH", "aI", "aJ", "aK", "aL", "aM", "aN", "aO", "aP", "aQ", "aR", "aS", "aT", "aU", "aV", "aW", "aX", "aY", "aZ"]
# Random demo data in place of confusion_matrix(y_test, y_pred);
# shape is (rows, columns) = (len(ix), len(cols))
data = np.random.randint(0, 10, (len(ix), len(cols)))
df = pd.DataFrame(data=data, columns=cols, index=ix)

plt.figure(figsize=(10, 10))
plt.imshow(df)

for r, (i, row) in enumerate(df.iterrows()):
    for c, entry in enumerate(row):
        # annotate with an offset on xy so the text sits roughly centered in each cell
        plt.annotate(str(entry), xy=(c - 0.3, r + 0.1), fontsize=8)

plt.show()
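A variant of the loop above, as a sketch: ax.text with ha="center" and va="center" anchors each number at its cell center, so the hand-tuned -0.3/+0.1 offsets are not needed, and switching to white text above half the maximum count keeps numbers readable on dark cells. The threshold rule and the "Blues" colormap are arbitrary choices, and random counts again stand in for the real matrix.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

letters = [chr(ord("A") + k) for k in range(26)]
# Random counts standing in for confusion_matrix(y_test, y_pred)
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.integers(0, 10, size=(26, 26)),
                  columns=["p" + c for c in letters],
                  index=["a" + c for c in letters])
values = df.to_numpy()

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(values, cmap="Blues")
threshold = values.max() / 2  # cells above this are dark enough to need white text

for r in range(values.shape[0]):
    for c in range(values.shape[1]):
        # ha/va="center" centers the text on (c, r), the middle of the cell
        ax.text(c, r, str(values[r, c]), ha="center", va="center", fontsize=8,
                color="white" if values[r, c] > threshold else "black")

ax.set_xticks(range(values.shape[1]))
ax.set_xticklabels(df.columns)
ax.set_yticks(range(values.shape[0]))
ax.set_yticklabels(df.index)
plt.show()
```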

