
plotly.figure_factory.create_annotated_heatmap doesn't show figure with axis labels correctly

I want to show annotated heatmaps in a Plotly Dash app. The heatmap works fine if I don't add axis labels, or if the labels aren't strings made up only of digits. But when I add my actual axis labels, the figure is too small and the annotations aren't shown correctly.

[Figure: if I add a non-numeric string to the labels, it works]

[Figure: with the labels I actually want to display]

I created a simple example in Dash to illustrate my problem.

import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
import numpy as np

import pandas as pd

external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/gapminderDataFiveYear.csv')
app.layout = html.Div([
    dcc.Graph(id='graph-with-slider'),
    dcc.Slider(
        id='year-slider',
        min=df['year'].min(),
        max=df['year'].max(),
        value=df['year'].min(),
        marks={str(year): str(year) for year in df['year'].unique()},
        step=None
    )
])


@app.callback(
    Output('graph-with-slider', 'figure'),
    Input('year-slider', 'value'))
def update_figure(selected_year):
    y = ['8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18']    
    x = ["023", "034", "045", "056", "067", "078"]
    
    
    z = [[4, 4, 2, 1, 0, 0],
        [0, 0, 0, 0, 0, 0], 
        [11, 2, 4, 0, 0, 1],
        [np.nan, 0, 0, 0, 0, 0], 
        [8, 1, 6, 1, 32, 3], 
        [5, 0, 0, 5, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [24, 2, 15, 1, 0, 5],
        [0, 0, 0, 0, 0, 0], 
        [0, 0, 8, 0, 7, 0],
        [0, 0, 0, 9, 0, 0]]
        
    ## It works if I enable the next two lines of code,
    #  or if the axis labels are plain numbers and not something like "032".
    
    
    #x=["add any non numerical string and it will work" + s  for s in x]
    #y=["add any non numerical string and it will work" + s for s in y]
    

    #fig = go.Figure(data=go.Heatmap(z=z))
    
    fig = ff.create_annotated_heatmap(z=z, x=x, y=y, colorscale=["green", "yellow", "orange", "red"], showscale=True)
    
    layout = go.Layout(width=500, height=500,
            hovermode='closest',
            autosize=True,
            xaxis=dict(zeroline=False),
            yaxis=dict(zeroline=False, autorange='reversed')
         )
    fig = go.Figure(data=fig, layout=layout)
  
    
    
    return fig


if __name__ == '__main__':
    app.run_server(debug=True)
    
    

I couldn't replicate your exact issue, but I have seen similar behaviour.

Try: fig.update_xaxes(type='category')

Plotly tries to infer the axis type in the background, and labels that look numeric (such as "023") can end up on a 'linear' axis. Setting the type explicitly may avoid this.

Some background is in the Plotly documentation on categorical axes.

It's a bug. The problem is (I'm pretty sure) in the '_AnnotatedHeatmap' class operating behind the scenes. The 'create_annotated_heatmap' function is not a very smart function. If you look at the code, it's surprisingly straightforward (it seemed to me more like a quick and dirty solution than the thoroughly developed, streamlined Plotly stuff). I didn't bother trying to work it out; I just worked around it. Here's my solution:

# Build the heatmap without x/y, then put the labels on as tick text.
fig = ff.create_annotated_heatmap(z=z,
                                  colorscale='greens')
fig.update_layout(overwrite=True,
                  xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)", tickvals=list(range(len(x))), ticktext=x),
                  yaxis=dict(ticks="", dtick=1, ticksuffix="   ", tickvals=list(range(len(y))), ticktext=y))
fig.show()

The problem with this workaround is that the built-in data-to-axis mapping doesn't happen, so you have to map the data yourself. Make sure you sort your axes and have your data sorted the same way! I haven't verified that it maps the data right, but I think it does. To ease this issue, here's the code I wrote for this:

# Sorted unique axis values; these become the tick labels.
x = sorted(df['X'].unique().tolist())
y = sorted(df['Y'].unique().tolist())

# Build z one row per y value, filling missing (x, y) pairs with 0.
z = []
for y_item in y:
    row = []
    for x_item in x:
        z_data_point = df[(df['X'] == x_item) & (df['Y'] == y_item)]['Z']
        row.append(0 if len(z_data_point) == 0 else z_data_point.iloc[0])
    z.append(row)

Clear as mud? I hope I typed that out right. If not, please let me know in the comments. I'm sure this last bit of code could be improved, so if you know how, feel free to leave it in a comment and I'll update it.
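One possible way to tighten the loop, sketched with pandas `pivot_table`. The sample `df` below is hypothetical; only the column names 'X', 'Y', and 'Z' come from the code above. It should give the same sorted axes and the same "first match, else 0" behaviour, but treat it as a sketch and not a verified drop-in replacement.

```python
import pandas as pd

# Hypothetical sample data; only the column names follow the loop above.
df = pd.DataFrame({
    "X": ["034", "023", "023"],
    "Y": ["9", "8", "9"],
    "Z": [2, 4, 7],
})

# Rows are the sorted Y values, columns the sorted X values.
# aggfunc="first" keeps the first Z for each (X, Y) pair, like .iloc[0];
# pairs with no data become NaN and are filled with 0.
grid = df.pivot_table(index="Y", columns="X", values="Z", aggfunc="first").fillna(0)

x = grid.columns.tolist()
y = grid.index.tolist()
z = grid.to_numpy().tolist()  # values come back as floats after fillna
```

Note that string labels sort lexicographically here, exactly as with `sorted()` in the loop, so '10' comes before '8'.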

