
Matplotlib: Plotting multiple histograms in plt.subplots

Background of the problem:

I'm working on a class that takes an Axes object as a constructor parameter and produces an (m,n) grid of subplots with a histogram in each cell, similar to the figure below:

[Image: example figure with a histogram in each cell of an (m,n) grid]

There are two important constraints here that I'm not allowed to change in any way:

  1. The Figure object is not passed as a constructor parameter; only the Axes object is. So the plt.subplots() call that creates it cannot be modified in any way.
  2. By default, the Axes parameter is that of a (1,1) figure (as below). All the modifications required to turn it into an (m,n) figure are performed within the class (inside its methods).
_, ax = plt.subplots() # By default takes (1,1) dimension
cm = ClassName(model, ax=ax, histogram=True) # calling my class

What I'm stuck on:

Since I want a histogram in each cell, I decided to loop over the cells and create a histogram for each one:

results[col].hist(ax=self.ax[y,x], bins=bins)

However, I can't specify the axes for each histogram. The Axes passed in comes from a default (1,1) figure, so it is a single Axes object rather than an array of axes, and it can't be indexed. When I try, I get this TypeError:

TypeError: 'AxesSubplot' object is not subscriptable
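To see where this TypeError comes from, the small sketch below compares what plt.subplots() returns by default with what it returns when squeeze=False is passed. It assumes the non-interactive Agg backend so it runs without a display. This only illustrates the cause: it does not meet the constraint above, because the caller's plt.subplots() call can't be changed.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np

# Default call, as in the question: nrows=ncols=1, so a single Axes comes back
_, ax = plt.subplots()
is_array = isinstance(ax, np.ndarray)

try:
    ax[0, 0]  # a single Axes object has no __getitem__
    subscriptable = True
except TypeError:
    subscriptable = False

# For comparison only: squeeze=False always returns a 2-D array of Axes
_, axs = plt.subplots(squeeze=False)
shape = axs.shape  # (1, 1) for the default grid
```

The exact class name in the error message depends on the Matplotlib version (older versions report AxesSubplot, newer ones Axes).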

Given these constraints, is there any way I can add my histograms to the parent Axes object? Thanks for taking a look.

The requirement is pretty strict and maybe not the best design choice. Because you later want to draw several subplots in place of a single subplot, that single subplot exists only to be removed and replaced a few moments later.

So what you can do is obtain the position of the axes you pass in and create a new gridspec at that position. Then remove the original axes and create a new set of axes within that newly created gridspec.

The following is an example. Note that it currently requires the axes passed in to be a subplot (as opposed to arbitrary axes). It also hardcodes the number of plots as 2×2; in a real use case you would probably derive that number from the model you pass in.

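import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec

class ClassName():
    def __init__(self, model, ax=None, **kwargs):
        ax = ax if ax is not None else plt.gca()
        # get_subplotspec() is missing (old Matplotlib) or returns None (newer
        # Matplotlib) for axes that were not created from a gridspec
        subplotspec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
        if subplotspec is None:
            raise ValueError("Axes needs to be a subplot")

        # Geometry of subplots
        m, n = 2, 2
        # New m x n grid that occupies exactly the cell of the original axes
        # (ax.get_geometry(), used in older versions of this answer, was
        # removed in Matplotlib 3.6)
        gs = gridspec.GridSpecFromSubplotSpec(m, n, subplot_spec=subplotspec)
        fig = ax.figure
        ax.remove()  # the placeholder axes is no longer needed

        self.axes = np.empty((m, n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i, j] = fig.add_subplot(gs[i, j], label=f"{i}{j}")

    def plot(self, data):
        # One row of data per cell, filling the grid row by row
        for ax, d in zip(self.axes.flat, data):
            ax.plot(d)


_, (ax, ax2) = plt.subplots(ncols=2)
cm = ClassName("mymodel", ax=ax2)  # calling my class
cm.plot(np.random.rand(4, 10))

plt.show()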

[Image: resulting figure, in which the right-hand subplot has been replaced by a 2×2 grid of line plots]
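Since the question is about histograms and the grid size should come from the data rather than being hardcoded, here is a minimal sketch of the same gridspec technique with those two changes. split_axes is a hypothetical helper (not part of Matplotlib); it derives the number of rows from the number of plots and a chosen column count, and hides any leftover cells. The data is random and purely illustrative, and the Agg backend is assumed so it runs headless.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec


def split_axes(ax, n_plots, ncols=2):
    """Hypothetical helper: replace `ax` with a grid of axes for `n_plots` plots.

    Returns a 2-D array of axes; cells beyond `n_plots` are hidden.
    """
    subplotspec = ax.get_subplotspec()
    if subplotspec is None:
        raise ValueError("Axes needs to be a subplot")
    nrows = math.ceil(n_plots / ncols)
    # The new grid lives inside the cell that `ax` occupied
    gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=subplotspec)
    fig = ax.figure
    ax.remove()

    axes = np.empty((nrows, ncols), dtype=object)
    for k in range(nrows * ncols):
        y, x = divmod(k, ncols)  # row-major position of cell k
        axes[y, x] = fig.add_subplot(gs[y, x])
        if k >= n_plots:
            axes[y, x].set_visible(False)  # leftover cell with nothing to draw
    return axes


# Illustrative data: 200 random samples for each of 5 variables
rng = np.random.default_rng(0)
data = rng.normal(size=(200, 5))

_, ax = plt.subplots()  # the caller's default (1,1) axes
axes = split_axes(ax, n_plots=data.shape[1], ncols=2)
for k, column in enumerate(data.T):
    axes.flat[k].hist(column, bins=20)  # one histogram per cell
```

With 5 variables and ncols=2 this yields a 3×2 grid with the last cell hidden. Hiding the spare cell, rather than not creating it, keeps the array rectangular so it can still be indexed as axes[y, x].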
