
Retention heatmap in plotly

For convenience, I am moving my retention graph from Seaborn to Plotly so that I can add shapes to it later. The plotly.graph_objects module seems suitable for this, but I don't understand how to pass DataFrame data to it.

import pandas as pd
import numpy as np
import seaborn as sns
import plotly as ply
import matplotlib.pyplot as plt
import plotly.graph_objects as go

df=pd.DataFrame(index=['01.2020','02.2020','03.2020','04.2020','05.2020','06.2020'],
                data={0:[1,1,1,1,1,1],
                    1:[0.58, 0.88, 0.27, 0.28, 0.68,0.90],
                    2:[0.56, 0.58, 0.1, 0.77, 0.68,None],
                    3:[0.78, 0.33, 0.4, 0.79, None,None],
                    4:[0.58, 0.16, 0.89, None, None,None],
                    5:[0.25, 0.14, None, None, None,None],
                    6:[0.69, None, None, None, None,None] })

sns.set(style='white')
plt.figure(figsize=(12, 8))
plt.title('Cohorts: User Retention')
sns.heatmap(df,annot=True, fmt='.0%');

How can I do this in Plotly?

There is already an existing answer that can help, but it is somewhat outdated because many of the methods it uses are now deprecated. As long as you are fine with changing your scale from 0-1 to 0-100, you could use plotly.figure_factory.create_annotated_heatmap, but as far as I know the figure_factory functions are going to be deprecated soon. The downside is that you have to write the annotations (the text) manually, as follows.

import pandas as pd
import numpy as np
import plotly.graph_objects as go

df = pd.DataFrame(index=['01.2020','02.2020','03.2020','04.2020','05.2020','06.2020'],
                  data={0:[1,1,1,1,1,1],
                        1:[0.58, 0.88, 0.27, 0.28, 0.68,0.90],
                        2:[0.56, 0.58, 0.1, 0.77, 0.68,None],
                        3:[0.78, 0.33, 0.4, 0.79, None,None],
                        4:[0.58, 0.16, 0.89, None, None,None],
                        5:[0.25, 0.14, None, None, None,None],
                        6:[0.69, None, None, None, None,None] })

z = df.values
x = df.columns
y = df.index
annotations = []
# Build one text annotation per cell; NaN cells get an empty label.
for n, row in enumerate(z):
    for m, val in enumerate(row):
        annotations.append(
            dict(text="{0:.0%}".format(val) if not np.isnan(val) else '',
                 x=x[m],
                 y=y[n],
                 xref='x1',
                 yref='y1',
                 showarrow=False))

layout = dict(title='Cohorts: User Retention',
              title_x=0.5,
              annotations=annotations,
              yaxis=dict(showgrid=False,
                         # labels such as '01.2020' look numeric to Plotly, so
                         # force a categorical axis instead of remapping ticks
                         type='category'
                        ),
              xaxis=dict(showgrid=False),
              width=700,
              height=700,
              autosize=False
             )


trace = go.Heatmap(x=x, y=y, z=z)
fig = go.Figure(data=trace, layout=layout)
fig.show()

