
Make a non symmetrical confusion matrix

I have a DataFrame with the columns Site, Label (the true class) and Prediction. It has n samples from each of 4 sites (A, B, C, D), so 4n samples in total. Each site belongs to one of two classes: sites A and B are class High, and sites C and D are class Low. Here is what the first 4 rows of my DataFrame look like:

Site  Label  Prediction
A     High   Low
B     High   High
C     Low    Low
D     Low    Low
...   ...    ...

Instead of the usual 2x2 confusion matrix that this data can produce, for example with 40 samples:

class  High  Low
Low      11    9
High      9   11

I want to create a 4x2 one so I can see which sites are being misclassified as High or Low. It would look something like this (using 10 samples from each site):

      A  B  C   D
Low   8  9  3   0
High  2  1  7  10

How can this be done? I would like the result as a NumPy array.

Here's a longer screenshot of the CSV (it has n samples): [image not available]
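For reference, the per-site table described in the question (predicted class per site, with correct predictions included) can be built with pandas' `pd.crosstab`, which counts every (prediction, site) pair. This is a minimal sketch, assuming the columns are named `Site`, `Label` and `Prediction` as in the question's table; the toy data below is made up for illustration.

```python
import pandas as pd

# Hypothetical toy data: 3 samples per site, same column names as the question's table
df = pd.DataFrame({
    "Site": ["A", "B", "C", "D"] * 3,
    "Label": ["High", "High", "Low", "Low"] * 3,
    "Prediction": ["Low", "High", "Low", "Low",
                   "High", "High", "High", "Low",
                   "Low", "Low", "Low", "Low"],
})

# Rows: predicted class, columns: site; each cell counts samples
counts = pd.crosstab(df["Prediction"], df["Site"])

# Fix the row/column order and add any class or site that never occurs as zeros
counts = counts.reindex(index=["Low", "High"], columns=["A", "B", "C", "D"], fill_value=0)

matrix = counts.to_numpy()  # plain 2x4 NumPy array of integer counts
print(counts)
print(matrix)
```

Passing `normalize="columns"` to `pd.crosstab` would give per-site proportions instead of raw counts, which may be easier to compare when sites have different sample sizes.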

I hope I understood your question correctly and that you would like to know the misclassified labels in each category for each site. If so, the following code should answer your question.

import numpy as np
from tabulate import tabulate
import pandas as pd

# Recreate data table
InputTable = {'Site': ['A', 'B', 'C', 'D', 'A', 'B', 'C', 'D', 'A', 'B', 'C', 'D'],
              'Label': ['High', 'High', 'Low', 'Low', 'High', 'High', 'Low', 'Low', 'High', 'High', 'Low', 'Low'],
              'Predicted': ['Low', 'High', 'Low', 'Low', 'Low', 'High', 'Low', 'Low', 'Low', 'High', 'Low', 'Low']}
df = pd.DataFrame(InputTable)

# Overview
SiteList = ['A', 'B', 'C', 'D']
Levels = ['Low', 'High']

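# Preallocate the result (rows: levels, columns: sites), filled with NaN
# (np.nan rather than np.NaN, which was removed in NumPy 2.0)
Result = np.full((len(Levels), len(SiteList)), np.nan)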

# Compute entries
for iter1 in range(len(Levels)):
    for iter2 in range(len(SiteList)):

        # Current selection
        SelectedLevel = Levels[iter1]
        SelectedSite = SiteList[iter2]

        # A misclassified level X has predicted label X and true label unequal to X
        Result[iter1, iter2] = len(df[ (df['Site'] == SelectedSite) & (df['Label'] != SelectedLevel) & (df['Predicted'] == SelectedLevel) ])

# Print result
print('\n')
print( tabulate(df, headers='keys', tablefmt='psql', showindex=True) )
print('\n')
print(pd.DataFrame(Result, index=Levels, columns=SiteList))

The result is

        A    B    C    D
Low   3.0  0.0  0.0  0.0
High  0.0  0.0  0.0  0.0

This is correct: the only misclassified samples are the three Site A samples that were predicted Low.
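The double loop above can also be written without explicit iteration: keep only the rows where the prediction differs from the true label, then cross-tabulate predicted level against site. With two levels, "predicted X and label not X" is the same as "predicted X and label differs from prediction", so this should give the same table. A minimal sketch reusing the answer's toy data and column names (`Label`, `Predicted`); `reindex` restores the all-zero rows and columns that `crosstab` would otherwise leave out.

```python
import pandas as pd

# Same toy data as in the answer above
df = pd.DataFrame({
    "Site": ["A", "B", "C", "D"] * 3,
    "Label": ["High", "High", "Low", "Low"] * 3,
    "Predicted": ["Low", "High", "Low", "Low"] * 3,
})

SiteList = ["A", "B", "C", "D"]
Levels = ["Low", "High"]

# Keep only misclassified samples (prediction differs from the true label)
wrong = df[df["Label"] != df["Predicted"]]

# Count misclassified samples per (predicted level, site); absent pairs become 0
Result = (
    pd.crosstab(wrong["Predicted"], wrong["Site"])
    .reindex(index=Levels, columns=SiteList, fill_value=0)
)
print(Result)
```

Unlike the NaN-preallocated array, this yields integer counts, and `Result.to_numpy()` gives the NumPy array if needed.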
