
How to plot a confusion matrix stored as an np.array as a categorical heatmap in bokeh?

My data:

short_labels = ["PERSON", "CITY", "COUNTRY", "O"]

cm = np.array([[0.53951528, 0.        , 0.        , 0.46048472],
               [0.06407323, 0.18077803, 0.        , 0.75514874],
               [0.06442577, 0.00560224, 0.08963585, 0.84033613],
               [       nan,        nan,        nan,        nan]])

I need to create a figure in bokeh that looks like the output of seaborn.heatmap(cm):

[image: seaborn heatmap of cm]

I would be extremely grateful for your help!

Code below:

colors = ["#0B486B", "#79BD9A", "#CFF09E", "#79BD9A", "#0B486B", "#79BD9A", "#CFF09E", "#79BD9A", "#0B486B"]

rand_cm_df = pd.DataFrame(rand_cm, columns=short_labels)
rand_hm = figure(title=f"Title",
                 toolbar_location=None,
                 x_range=short_labels,
                 y_range=short_labels,
                 output_backend="svg")
rand_hm.rect(source=ColumnDataSource(rand_cm_df), color=colors, width=1, height=1)

This returns the following error:

RuntimeError: 

Expected line_color, hatch_color and fill_color to reference fields in the supplied data source.

When a 'source' argument is passed to a glyph method, values that are sequences
(like lists or arrays) must come from references to data columns in the source.

For instance, as an example:

    source = ColumnDataSource(data=dict(x=a_list, y=an_array))

    p.circle(x='x', y='y', source=source, ...) # pass column names and a source

Alternatively, *all* data sequences may be provided as literals as long as a
source is *not* provided:

    p.circle(x=a_list, y=an_array, ...)  # pass actual sequences and no source

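You were close. The main thing missing was reshaping the data into long format, with one row per (row label, column label, value) cell, so that every glyph attribute can reference a column in the data source. You'll still need to tweak the attributes to tidy up the appearance, but this should get you up and running.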

import pandas as pd
import numpy as np
from bokeh.models import LinearColorMapper
from bokeh.plotting import figure, show
from bokeh.palettes import YlGn9

short_labels = ["PERSON", "CITY", "COUNTRY", "O"]

cm = np.array([[0.53951528, 0.        , 0.        , 0.46048472],
               [0.06407323, 0.18077803, 0.        , 0.75514874],
               [0.06442577, 0.00560224, 0.08963585, 0.84033613],
               [np.nan,     np.nan,     0.,     0.]])

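# Reshape the matrix into long format: one row per (row label, column label) cell.
# After reset_index(), the unnamed index levels become the columns
# 'level_0' (row label) and 'level_1' (column label).
cm_df = pd.DataFrame(cm, index=short_labels, columns=short_labels)
cm_df = pd.DataFrame(cm_df.stack(), columns=['val']).reset_index()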

colors = list(YlGn9)[::-1]
mapper = LinearColorMapper(palette=colors, low=0.0, high=1.0)

rand_hm = figure(
    title="Title",
    toolbar_location=None,
    x_range=short_labels,
    y_range=list(reversed(short_labels)),
    )

rand_hm.rect(
    source=cm_df,
    x='level_1',
    y='level_0',
    width=1, 
    height=1, 
    fill_color={'field':'val','transform': mapper},
    )

rand_hm.text(
    source=cm_df,
    x='level_1',
    y='level_0',
    text='val',
    text_font_size='10pt',
    x_offset=-30,
    y_offset=10
    )
    
show(rand_hm)

[image: the resulting bokeh heatmap]
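
As a possible starting point for the clean-up mentioned above, here is a minimal sketch of the reshaping step on its own, using the data from the question. It uses melt() instead of stack() and drops the NaN cells explicitly, since whether stack() drops NaN rows can depend on the pandas version. The column names row_label, col_label and val_text are my own choices for illustration, not anything bokeh requires, and val_text is just one way to get shorter cell labels.

```python
import numpy as np
import pandas as pd

short_labels = ["PERSON", "CITY", "COUNTRY", "O"]

cm = np.array([[0.53951528, 0.0, 0.0, 0.46048472],
               [0.06407323, 0.18077803, 0.0, 0.75514874],
               [0.06442577, 0.00560224, 0.08963585, 0.84033613],
               [np.nan, np.nan, np.nan, np.nan]])

# Wide format: one row and one column per label.
# 'row_label' is a hypothetical name chosen for this sketch.
wide = pd.DataFrame(cm,
                    index=pd.Index(short_labels, name="row_label"),
                    columns=short_labels)

# Long format: one row per matrix cell.
# 'col_label' and 'val' are likewise names chosen here.
long_df = wide.reset_index().melt(id_vars=["row_label"],
                                  var_name="col_label",
                                  value_name="val")

# Drop cells without a value so no glyph is drawn for them.
long_df = long_df.dropna(subset=["val"]).reset_index(drop=True)

# Pre-format the numbers as strings with two decimals for the text labels.
long_df["val_text"] = long_df["val"].map("{:.2f}".format)

print(long_df.head())
```

With a frame like this, the glyph calls in the code above would reference x='col_label', y='row_label' and text='val_text' instead of 'level_1', 'level_0' and 'val'. If you would rather keep the NaN cells and draw them in a neutral color, skip the dropna() step; LinearColorMapper has a nan_color property for that case.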


 