
How to create a heatmap animation with plotly express?

I have a list of square matrices, M[t], where t ranges from 0 to N, and I wish to create an animated heatmap using plotly.express. The rows and columns of each matrix correspond to the entries of a list, a=['a1','a2',...'aN'].

The plotly documentation on animation is fairly sparse and focuses only on scatter plots and bar charts:

https://plotly.com/python/animations/

A question similar to mine was posted at

How to animate a heatmap in Plotly

However, that user is working in a Jupyter notebook, whereas I'm simply using Python 3.7 with IDLE on a Mac (macOS 10.15.4).

I know how to create a basic animation with matplotlib or seaborn, but I like the built-in start/stop buttons that plotly express provides. Here's one approach I use, though I'm sure there are more efficient ways using matplotlib.animation:

import numpy as np
import matplotlib.pyplot as plt

# N = 50 matrices, each of size 4-by-4.
N = 50
M = np.random.random((N, 4, 4))

# Desired labels for heatmap--not sure where to put.
labels = ['a', 'b', 'c', 'd']

fig, ax = plt.subplots()

for t in range(N):
    ax.cla()                  # clear the previous frame
    ax.imshow(M[t])
    ax.set_title("frame {}".format(t))
    plt.pause(0.1)            # draw, then wait 0.1 s

plt.show()                    # keep the window open on the last frame
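For comparison, here is a minimal sketch of the matplotlib.animation route mentioned above, using FuncAnimation. Instead of clearing the axes every frame, it creates the image once and swaps its data with set_data, which should be somewhat cheaper. It also shows one place the labels can go (as tick labels). The fixed vmin=0/vmax=1 is an assumption that the data lie in [0, 1], as np.random.random guarantees.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

N = 50
M = np.random.random((N, 4, 4))
labels = ['a', 'b', 'c', 'd']

fig, ax = plt.subplots()
# Create the image once; vmin/vmax (assumed range [0, 1]) keep colors
# comparable from frame to frame.
im = ax.imshow(M[0], vmin=0, vmax=1)
fig.colorbar(im, ax=ax)

# Put the row/column labels on the ticks.
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticks(range(len(labels)))
ax.set_yticklabels(labels)

def update(t):
    """Show matrix M[t] by replacing the image data in place."""
    im.set_data(M[t])
    ax.set_title(f"frame {t}")
    return [im]

# Keep a reference to the animation, otherwise it is garbage-collected.
anim = FuncAnimation(fig, update, frames=N, interval=100)
plt.show()
```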

Does this work for you?

import numpy as np
import plotly.graph_objects as go  # "graph_objs" is the older alias

N = 50
M = np.random.random((N, 10, 10))

fig = go.Figure(
    data=[go.Heatmap(z=M[0])],            # shown before Play is pressed
    layout=go.Layout(
        title="Frame 0",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]  # None = play all frames in order
    ),
    frames=[go.Frame(data=[go.Heatmap(z=M[i])],
                     layout=go.Layout(title_text=f"Frame {i}"))
            for i in range(1, N)]
)

fig.show()

UPDATE: In case you also need a Pause button:

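import numpy as np
import plotly.graph_objects as go

N = 50
M = np.random.random((N, 10, 10))

fig = go.Figure(
    data=[go.Heatmap(z=M[0])],
    layout=go.Layout(
        title="Frame 0",
        title_x=0.5,
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          # fromcurrent: resume from the paused frame
                          args=[None, {"fromcurrent": True}]),
                     dict(label="Pause",
                          method="animate",
                          # [None] (a list) stops the animation;
                          # a bare None would replay all frames instead
                          args=[[None],
                                {"frame": {"duration": 0, "redraw": False},
                                 "mode": "immediate",
                                 "transition": {"duration": 0}}],
                          )])]
    ),
    frames=[go.Frame(data=[go.Heatmap(z=M[i])],
                     layout=go.Layout(title_text=f"Frame {i}"))
            for i in range(1, N)]
)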

fig.show()
