I've created a function that plays an animation controlled by a slider. Each frame consists of a heatmap (with a colorbar) above a barplot. The function's arguments are a list of text labels, used both for the heatmap axis labels and for the barplot's horizontal axis labels, a list of matrices for the heatmap, and a list of lists for the barplot. There is also a time-window value, `win_value`, so that frame zero corresponds to time zero, frame one to `win_value`, frame two to `2*win_value`, and so on.
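The frame/time correspondence is the linear mapping time = frame × `win_value`. Here is a minimal sketch of both directions. The helpers `frame_to_time` and `time_to_frame` are hypothetical and not part of the original code. The inverse mapping calls `round` before `int`, so floating-point error cannot push a value just below an integer and truncate it to the wrong frame:

```python
def frame_to_time(frame, win_value):
    """Start time of a frame: frame k begins at k * win_value."""
    return frame * win_value


def time_to_frame(t, win_value):
    """Inverse mapping; round() absorbs floating-point error before int()."""
    return int(round(t / win_value))


win_value = 0.25
times = [frame_to_time(k, win_value) for k in range(4)]
print(times)                                         # [0.0, 0.25, 0.5, 0.75]
print([time_to_frame(t, win_value) for t in times])  # [0, 1, 2, 3]
```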
The code for the function is as follows:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider


def heatmap_barplot_animation(labels, M_list, bar_list, win_value):
    num_times = len(M_list)
    fig, ax = plt.subplots(2)
    plt.subplots_adjust(left=None, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)
    ax_time = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # Slider counts frames: 0, 1, ..., num_times - 1
    s_time = Slider(ax_time, 'Time', 0, num_times - 1, valinit=0, valstep=1)

    def update_graph(val):
        i = int(s_time.val)  # current frame index
        ax[0].cla()
        heatmap = ax[0].imshow(M_list[i], vmin=0, vmax=1, cmap='coolwarm', aspect='auto')
        ax[0].set_xticks(range(len(labels)))
        ax[0].set_xticklabels(labels, fontsize=10)
        ax[0].set_yticks(range(len(labels)))
        ax[0].set_yticklabels(labels, fontsize=10)
        ax0_divider = make_axes_locatable(ax[0])
        cax0 = ax0_divider.append_axes('right', size='7%', pad='2%')
        fig.colorbar(heatmap, cax=cax0, orientation='vertical')
        ax[1].cla()
        ax[1].bar(labels, bar_list[i])
        ax[1].set_ylim(0, 1)
        fig.canvas.draw_idle()  # redraw after each slider move

    s_time.on_changed(update_graph)
    s_time.set_val(0)
    plt.show()  # show the figure once, not from inside the callback
```
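An example with eight labels, 10 frames, and a window value of 0.25 seconds:

```python
import random

import numpy as np

labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
Nc = len(labels)  # number of labels (heatmap size and number of bars)
Nt = 10           # number of frames
M_list = [np.random.rand(Nc, Nc) for i in range(Nt)]
bar_list = [[random.uniform(0, 1) for i in range(Nc)] for t in range(Nt)]
win_value = .25
heatmap_barplot_animation(labels, M_list, bar_list, win_value)
```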
The third frame of the animation looks like this:
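I can't figure out what modifications are needed to do the following:

1. Center the slider on the two subplots.
2. Label the slider in units of time (0, `win_value`, `2*win_value`, ...) instead of frame numbers.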
For your first question, one way to center your slider on the subplots is to adjust the position of the subplots with `plt.subplots_adjust` so that they match the axes of the slider. In your code the slider's axes are defined with `ax_time=fig.add_axes([0.25, 0.1, 0.65, 0.03])`, i.e. they start at 0.25 and are 0.65 wide in figure coordinates, so they end at 0.9. That right edge happens to equal Matplotlib's default `figure.subplot.right` of 0.9, so you only need to move the left edge: `plt.subplots_adjust(left=0.25, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)`. You can play around with the axes of the slider and of the subplots to get the result you want (see below for an example with the slider centered on the subplots).
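A small sketch of the same idea derives the subplot edges directly from the slider rectangle, so the two cannot drift apart if you later change one of them. The names `slider_rect`, `slider_left`, and `slider_right` are illustrative. The Agg backend is selected only so the snippet runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so this sketch runs without a window
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

# Slider rectangle in figure coordinates: [left, bottom, width, height]
slider_rect = [0.25, 0.1, 0.65, 0.03]
slider_left = slider_rect[0]
slider_right = slider_rect[0] + slider_rect[2]  # 0.25 + 0.65

fig, ax = plt.subplots(2)
# Give both subplots the same horizontal extent as the slider
fig.subplots_adjust(left=slider_left, right=slider_right, bottom=0.2, top=0.9, hspace=0.2)
ax_time = fig.add_axes(slider_rect)
s_time = Slider(ax_time, 'Time', 0, 1, valinit=0)

for a in ax:
    pos = a.get_position()
    print(round(pos.x0, 3), round(pos.x1, 3))  # left and right edge of each subplot
```

In the full function, `make_axes_locatable` then takes the colorbar out of the top subplot's width, so the heatmap itself ends a little to the left of the slider's right edge, and the colorbar ends at that edge.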
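In response to your second question, to relabel the slider in units of time you just need to change its `valmax` and `valstep` values to `valmax=(num_times-1)*win_value` and `valstep=win_value`, so that the slider runs from 0 to the start time of the last frame in steps of `win_value`. To index your `M_list` and `bar_list` arrays, you then need to convert the slider value back into a frame index, declaring `i` as `i=int(round(s_time.val/win_value))`. The `round` matters because `k*win_value/win_value` is not guaranteed to be exactly `k` in floating point, and `int` alone truncates toward zero. With this mapping, time 0 shows `M_list[0]` and time `k*win_value` shows `M_list[k]`.

Here is the code you provided with the modifications described above. It also draws the heatmap and its colorbar once and only updates the image data with `set_data` when the slider moves, so a new colorbar axes is not appended on every update. It also calls `plt.show()` once at the end instead of inside the callback:

```python
import random

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import Slider


def heatmap_barplot_animation(labels, M_list, bar_list, win_value):
    num_times = len(M_list)
    fig, ax = plt.subplots(2)
    # Subplots start at left=0.25 like the slider; right=None keeps the default 0.9 = 0.25 + 0.65
    plt.subplots_adjust(left=0.25, bottom=.2, right=None, top=.9, wspace=.2, hspace=.2)
    ax_time = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    # Slider values are times: 0, win_value, 2*win_value, ..., (num_times-1)*win_value
    s_time = Slider(ax_time, 'Time', valmin=0, valmax=(num_times - 1) * win_value,
                    valinit=0, valstep=win_value)

    # Draw the heatmap, its tick labels and its colorbar once
    heatmap = ax[0].imshow(M_list[0], vmin=0, vmax=1, cmap='coolwarm', aspect='auto')
    ax[0].set_xticks(range(len(labels)))
    ax[0].set_xticklabels(labels, fontsize=10)
    ax[0].set_yticks(range(len(labels)))
    ax[0].set_yticklabels(labels, fontsize=10)
    ax0_divider = make_axes_locatable(ax[0])
    cax0 = ax0_divider.append_axes('right', size='7%', pad='2%')
    fig.colorbar(heatmap, cax=cax0, orientation='vertical')

    def update_graph(val):
        # Convert the slider's time value back into a frame index
        i = int(round(s_time.val / win_value))
        heatmap.set_data(M_list[i])  # vmin/vmax are fixed, so the colorbar stays valid
        ax[1].cla()
        ax[1].bar(labels, bar_list[i])
        ax[1].set_ylim(0, 1)
        fig.canvas.draw_idle()

    s_time.on_changed(update_graph)
    s_time.set_val(0)
    plt.show()


labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
Nc = 8   # number of labels (heatmap size and number of bars)
Nt = 10  # number of frames
M_list = [np.random.rand(Nc, Nc) for i in range(Nt)]
bar_list = [[random.uniform(0, 1) for i in range(Nc)] for t in range(Nt)]
win_value = .25
heatmap_barplot_animation(labels, M_list, bar_list, win_value)
```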
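You can check the time-to-index conversion without opening a window. This sketch drives a slider with the same settings programmatically through `set_val` and records which frame index the callback picks. `frames_seen` is a hypothetical name used only for this check, and the Agg backend is chosen so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a window
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

win_value = 0.25
num_times = 10
frames_seen = []  # frame indices chosen by the callback

fig = plt.figure()
ax_time = fig.add_axes([0.25, 0.1, 0.65, 0.03])
s_time = Slider(ax_time, 'Time', valmin=0, valmax=(num_times - 1) * win_value,
                valinit=0, valstep=win_value)


def update_graph(val):
    # Same time-to-index conversion as in update_graph above
    frames_seen.append(int(round(s_time.val / win_value)))


s_time.on_changed(update_graph)
for t in (0, 0.75, 2.25):
    s_time.set_val(t)
print(frames_seen)  # [0, 3, 9]
```

When the slider is dragged instead of being set in code, `valstep=win_value` makes it snap to multiples of `win_value`, so the callback only ever receives times that map to valid frames.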
And here is the output at frame 3: