Numpy and Matplotlib: problem plotting a list of matrices with imshow or pcolor

I have a question; the question text is:

Make a function plot_list(Xs, n_per_row) that takes as input a list of numpy 2-dimensional arrays and a parameter n_per_row that sets the number of elements displayed in a single row.

The output for the following list of 4 arrays should be as shown in the figure below.

plot_list(Xs, n_per_row=2)

import numpy as np

Xs = [
    [[1, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],
    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 1, 0, 0]],
    [[0, 0, 1, 1],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],
    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 1, 1]],
]

Xs = np.array(Xs)

should produce an image like this:

[Image: expected output]

I have written the following function:

import numpy as np
import matplotlib.pyplot as plt

def plot_list(Xs, n_per_row=5):
    '''Takes an input of 4 arrays only, n_per_row'''

    a = np.array(Xs)

    fig, axs = plt.subplots(nrows=a.shape[0]//n_per_row, ncols=n_per_row, figsize=(10, 10))

    for i, ax in enumerate(fig.axes):
        if a.ndim == 3:
            ax.pcolormesh(a[i], cmap='Greys')
        elif a.ndim < 3:
            ax.pcolormesh(a, cmap='Greys')
        ax.grid(True, lw=1)
        ax.set_ylim(ax.get_ylim()[::-1])   # flip the y-axis so row 0 is at the top

    plt.show()

When called as plot_list(Xs, n_per_row=2), it produces the following output (notice that the grid lines are not aligned with the cells the way I intended):

[Image: output of plot_list(Xs, n_per_row=2)]

However, when I call the function to show all 4 results in one row instead of 2 per row, with plot_list(Xs, n_per_row=4), I get the abomination below:

[Image: output of plot_list(Xs, n_per_row=4)]

Notice that the y-axis is 8 units long but the x-axis is only 4. Does anyone know how to fix these two issues: the grid alignment and the shortened x-axis?

Thanks

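A few things need improvement:

  1. Change the figure size according to the layout (n_per_row columns by nrows rows), so each panel gets roughly square space.
  2. Use set_major_locator to force a grid line at every integer, i.e. at every cell boundary drawn by pcolormesh.
  3. Use set_aspect('equal') so each cell is drawn square.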
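The figure-size point is what fixes the stretched axes: with figsize=(10, 10) and a 1×4 layout, each column gets only about a quarter of the width but the full height, so the panels come out tall and narrow. Below is a minimal sketch of the sizing rule, assuming a fixed `cell_inches` per panel. `grid_layout` is a hypothetical helper name. Unlike the `//` in the code below, it rounds the row count up with math.ceil so a leftover array is not silently dropped.

```python
import math

def grid_layout(n_items, n_per_row, cell_inches=5):
    """Hypothetical helper: rows, columns and figsize for n_items square panels.

    Rounds the row count up (math.ceil) so no item is dropped when
    n_items is not a multiple of n_per_row.
    """
    # never ask for more columns than there are items
    ncols = min(n_per_row, n_items)
    nrows = math.ceil(n_items / n_per_row)
    # figsize is (width, height) in inches: scale each by the panel count
    figsize = (ncols * cell_inches, nrows * cell_inches)
    return nrows, ncols, figsize

print(grid_layout(4, 2))  # (2, 2, (10, 10))
print(grid_layout(4, 4))  # (1, 4, (20, 5))
print(grid_layout(5, 2))  # (3, 2, (10, 15))
```

With the default n_per_row=5 and only 4 arrays, a.shape[0]//n_per_row is 0, and plt.subplots rejects a grid with 0 rows; rounding up gives 1 row instead, and min() keeps it at 4 columns rather than leaving an empty fifth panel.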

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import ticker as mticker

def plot_list(Xs, n_per_row=5):
    '''Takes an input of 4 arrays only, n_per_row'''

    a = np.array(Xs)

    nrows = a.shape[0] // n_per_row

    fig, axs = plt.subplots(nrows=nrows, ncols=n_per_row,
                            figsize=(n_per_row*5, nrows*5))   # adjust the figsize here

    for i, ax in enumerate(fig.axes):
        if a.ndim == 3:
            ax.pcolormesh(a[i], cmap='Greys')
        elif a.ndim < 3:
            ax.pcolormesh(a, cmap='Greys')
        ax.grid(True, lw=1)

        # set locator here: one major tick (and grid line) per cell boundary
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))
        ax.set_ylim(ax.get_ylim()[::-1])

        ax.set_aspect('equal')   # draw each cell as a square

    plt.show()
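Since the title also mentions imshow, here is a sketch of the same plot using imshow instead of pcolormesh. It assumes that all arrays share one shape and hold 0/1 values as in the question. imshow already draws row 0 at the top and uses square pixels by default, so no y-limit flip or set_aspect call is needed. The catch is that pixel centres sit on integers, so the cell borders are at half-integers, and the grid is drawn on minor ticks placed there. The name `plot_list_imshow` is hypothetical, and it returns the figure instead of calling plt.show() so the result can be saved or inspected.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch; remove to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def plot_list_imshow(Xs, n_per_row=5, cell_inches=3):
    """Hypothetical imshow-based variant of plot_list; returns the Figure."""
    a = np.asarray(Xs)
    if a.ndim == 2:                    # a single matrix: treat it as a list of one
        a = a[np.newaxis]
    n = a.shape[0]
    ncols = min(n_per_row, n)
    nrows = math.ceil(n / n_per_row)   # round up so no matrix is dropped

    fig, axs = plt.subplots(nrows, ncols, squeeze=False,
                            figsize=(ncols * cell_inches, nrows * cell_inches))
    for ax in axs.flat[n:]:            # hide unused panels in the last row
        ax.set_visible(False)

    for X, ax in zip(a, axs.flat):
        h, w = X.shape
        ax.imshow(X, cmap="Greys", vmin=0, vmax=1)   # assumes 0/1 data
        # major ticks on cell centres (integers), labelled with row/column index
        ax.set_xticks(np.arange(w))
        ax.set_yticks(np.arange(h))
        # cell borders are at -0.5, 0.5, ..., so put minor ticks there
        ax.set_xticks(np.arange(-0.5, w, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, h, 1), minor=True)
        ax.grid(which="minor", color="gray", lw=1)
        ax.tick_params(which="minor", length=0)      # grid lines only, no tick marks
    return fig
```

For example, fig = plot_list_imshow(Xs, n_per_row=4) followed by fig.savefig("grid.png") writes the one-row version. The major ticks are set explicitly on integers so they never coincide with the half-integer minor ticks; Matplotlib drops minor ticks that duplicate major ones, which would otherwise remove some grid lines.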

Then plot_list(Xs, n_per_row=2) outputs:

[Image: output of the fixed plot_list(Xs, n_per_row=2)]

and plot_list(Xs, n_per_row=4) gives:

[Image: output of the fixed plot_list(Xs, n_per_row=4)]
