I have a question. The exercise text is:
Make a function plot_list(Xs, n_per_row)
that takes as input a list of 2-dimensional NumPy arrays and a parameter n_per_row that sets the number of elements displayed in a single row.
For the following list of 4 arrays, calling plot_list(Xs, n_per_row=2) should produce the output shown in the figure below.

import numpy as np

Xs = [
    [[1, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],
    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 1, 0, 0]],
    [[0, 0, 1, 1],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],
    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 1, 1]],
]
Xs = np.array(Xs)

plot_list(Xs, n_per_row=2)

[Figure: expected output of plot_list(Xs, n_per_row=2)]
I have written the following function:

import numpy as np
import matplotlib.pyplot as plt

def plot_list(Xs, n_per_row=5):
    '''Takes an input of 4 arrays only, n_per_row'''
    a = np.array(Xs)
    fig, axs = plt.subplots(nrows=a.shape[0]//n_per_row, ncols=n_per_row, figsize=(10, 10))
    for i, ax in enumerate(fig.axes):
        if a.ndim == 3:
            ax.pcolormesh(a[i], cmap='Greys')
        elif a.ndim < 3:
            ax.pcolormesh(a, cmap='Greys')
        ax.grid(True, lw=1)
        ax.set_ylim(ax.get_ylim()[::-1])
    plt.show()
When called as plot_list(Xs, n_per_row=2), it produces the following output (notice that the grid lines are not aligned with the cells the way I intended):

[Figure: output of plot_list(Xs, n_per_row=2)]

However, when I call the function to show all 4 results in one row instead of 2, with plot_list(Xs, n_per_row=4), I get the abomination below:

[Figure: output of plot_list(Xs, n_per_row=4)]

Notice that the y axis is 8 units long but the x axis is only 4. Does anyone know how to fix these two issues, the grid alignment and the shortened x axis?
Thanks
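A note on the stretched axis, offered as a sketch: both axes should still span 0 to 4 in data units, since pcolormesh places a 4 × 4 array's cell edges at 0, 1, 2, 3, 4. What changes is the shape of each axes box. With figsize=(10, 10) shared by one row of four subplots, each box is several times taller than it is wide, and the default aspect ('auto') stretches the mesh to fill it. The hypothetical helper axes_box_ratio below measures that height/width ratio, assuming matplotlib's default subplot spacing and the non-interactive Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt
import numpy as np


def axes_box_ratio(nrows, ncols, figsize=(10, 10)):
    """Hypothetical helper: height/width (in inches) of one subplot box in an nrows x ncols grid."""
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    ax = fig.axes[0]
    ax.pcolormesh(np.eye(4), cmap="Greys")  # any 4x4 array, as in the question
    box = ax.get_position()                 # Bbox in figure-fraction coordinates
    fig_w, fig_h = fig.get_size_inches()
    ratio = (box.height * fig_h) / (box.width * fig_w)
    plt.close(fig)
    return ratio


print(axes_box_ratio(nrows=2, ncols=2))  # 2 per row: boxes are close to square
print(axes_box_ratio(nrows=1, ncols=4))  # 4 per row: boxes are much taller than wide
```

With the default spacing the first value comes out close to 1 and the second well above 3, which matches the tall, narrow plots in the n_per_row=4 output.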
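A few things need improvement:
- scale figsize with the number of rows and columns instead of fixing it at (10, 10);
- use set_major_locator to force a grid line at every integer, i.e. at every cell edge;
- call ax.set_aspect('equal') so that each cell is drawn square.

from matplotlib import ticker as mticker
import matplotlib.pyplot as plt
import numpy as np

def plot_list(Xs, n_per_row=5):
    '''Takes an input of 4 arrays only, n_per_row'''
    a = np.array(Xs)
    nrows = a.shape[0] // n_per_row
    fig, axs = plt.subplots(nrows=nrows, ncols=n_per_row,
                            figsize=(n_per_row*5, nrows*5))  # adjust the figsize here
    for i, ax in enumerate(fig.axes):
        if a.ndim == 3:
            ax.pcolormesh(a[i], cmap='Greys')
        elif a.ndim < 3:
            ax.pcolormesh(a, cmap='Greys')
        ax.grid(True, lw=1)
        # set locator here: one major tick (and hence one grid line) per integer
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))
        ax.set_ylim(ax.get_ylim()[::-1])  # put row 0 at the top
        ax.set_aspect('equal')            # square cells
    plt.show()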
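Why the locator fixes the grid: ax.grid draws lines at the major tick positions, and the default locator picks its tick spacing from the size of the axes, so it may put ticks (and grid lines) at half-integers, through the middle of the cells. MultipleLocator(1) puts a tick at every integer, which is exactly where pcolormesh draws the cell edges. A small check, assuming the Agg backend; x_grid_positions is a hypothetical helper:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import ticker as mticker


def x_grid_positions(use_multiple_locator):
    """Hypothetical helper: x positions of the major ticks, where ax.grid() draws lines."""
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.pcolormesh(np.eye(4), cmap="Greys")  # cell edges at x = 0, 1, 2, 3, 4
    if use_multiple_locator:
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    lo, hi = sorted(ax.get_xlim())
    # Locators may return ticks just outside the view; keep only the visible ones.
    ticks = [float(t) for t in ax.get_xticks() if lo <= t <= hi]
    plt.close(fig)
    return ticks


print("default locator:   ", x_grid_positions(False))  # spacing chosen automatically
print("MultipleLocator(1):", x_grid_positions(True))   # one line per cell edge
```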
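With these changes, plot_list(Xs, n_per_row=2) outputs:

[Figure: output of the fixed plot_list(Xs, n_per_row=2)]

and plot_list(Xs, n_per_row=4) gives:

[Figure: output of the fixed plot_list(Xs, n_per_row=4)]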
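One limitation remains in the code above: nrows = a.shape[0] // n_per_row rounds down, so with the default n_per_row=5 and four arrays it requests zero rows (which plt.subplots should reject), and with n_per_row=3 the fourth array is silently dropped. A possible generalization, offered as a sketch rather than as the answer's method: round the row count up, pass squeeze=False so axs is always 2-D, and hide the unused axes. The extra cell_inches parameter and returning the Figure instead of calling plt.show() are my own assumptions:

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line for an on-screen window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import ticker as mticker


def plot_list(Xs, n_per_row=5, cell_inches=3):
    """Plot each 2-D array in Xs on its own square, gridded subplot.

    cell_inches (hypothetical parameter): size of one subplot in inches.
    """
    arrays = [np.asarray(X) for X in Xs]
    if not arrays:
        raise ValueError("Xs must contain at least one array")
    n = len(arrays)
    ncols = min(n_per_row, n)        # no empty columns when there are few arrays
    nrows = math.ceil(n / ncols)     # round up so no array is dropped
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(ncols * cell_inches, nrows * cell_inches),
                            squeeze=False)  # axs is always 2-D, even for a single row
    flat_axes = axs.ravel()
    for ax, X in zip(flat_axes, arrays):
        ax.pcolormesh(X, cmap="Greys")
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))  # grid line at every cell edge
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))
        ax.grid(True, lw=1)
        ax.invert_yaxis()            # row 0 at the top, as in the printed array
        ax.set_aspect("equal")       # square cells regardless of the figure shape
    for ax in flat_axes[n:]:
        ax.set_visible(False)        # hide unused axes in the last row
    return fig


# The four arrays from the question, built with slicing.
Xs = np.zeros((4, 4, 4), dtype=int)
Xs[0, :2, 0] = 1
Xs[1, 2:, 1] = 1
Xs[2, 0, 2:] = 1
Xs[3, 3, 2:] = 1

fig = plot_list(Xs, n_per_row=3)  # 3 plots in the first row, 1 in the second
```

Because each subplot gets a fixed square area and aspect='equal', the same call works for n_per_row=2, 4, or the default 5 without the stretching seen in the question.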