Background of the problem:
I'm working on a class that takes an Axes object as a constructor parameter and produces an (m, n) figure with a histogram in each cell, similar to the figure below:
There are two important things to note here that I'm not allowed to modify in any way:
_, ax = plt.subplots()  # creates a single (1, 1) Axes by default
cm = ClassName(model, ax=ax, histogram=True) # calling my class
What I'm stuck on:
Since I want to plot a histogram in each cell, I decided to loop over the cells and create a histogram for each one:
results[col].hist(ax=self.ax[y,x], bins=bins)
However, I can't tell each histogram which Axes to draw on this way. The Axes passed in comes from the default (1, 1) layout, so it is a single Axes object rather than an array, and it cannot be indexed. When I try, I get a TypeError:
TypeError: 'AxesSubplot' object is not subscriptable
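A short sketch of why the error appears: with the default `squeeze=True`, `plt.subplots()` returns a bare Axes for a 1x1 layout, and only `squeeze=False` guarantees a 2D array. This is just an illustration of the cause; it is not a fix here, since the caller's code cannot be changed. The variable `error_message` is an illustrative name, and the exact class name in the message (`Axes` vs. `AxesSubplot`) depends on the Matplotlib version.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# Default squeeze=True: a 1x1 layout yields a single Axes object
fig, ax = plt.subplots()
print(isinstance(ax, Axes))  # True

error_message = None
try:
    ax[0, 0]  # same kind of indexing as self.ax[y, x]
except TypeError as err:
    error_message = str(err)
print(error_message)

# squeeze=False always returns a 2D NumPy array of Axes, even for 1x1
fig2, axs = plt.subplots(squeeze=False)
print(axs.shape)  # (1, 1)
```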
Given all this, what are possible ways to add my histograms to the parent Axes object? Thanks for taking a look.
The requirement is pretty strict and maybe not the best design choice. Because you later want to plot several subplots in the position of a single subplot, that single subplot is created only to be removed and replaced a few moments later.
So what you can do is obtain the position of the Axes you pass in and create a new gridspec at that position. Then remove the original Axes and create a new set of Axes within the newly created gridspec.
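Those three steps can be sketched compactly with `SubplotSpec.subgridspec`, which recent Matplotlib releases provide as a shortcut for `GridSpecFromSubplotSpec`. `split_axes` is a hypothetical helper name, and the sketch assumes the passed-in Axes is a subplot.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def split_axes(ax, m, n):
    """Replace subplot `ax` with an m x n grid of new Axes in the same slot.

    Hypothetical helper, not part of Matplotlib.
    """
    spec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
    if spec is None:
        raise ValueError("Axes needs to be a subplot")
    fig = ax.figure
    # 1. nested m x n gridspec covering the slot of the original Axes
    inner = spec.subgridspec(m, n)
    # 2. the original Axes was only a placeholder
    ax.remove()
    # 3. one new Axes per cell of the nested gridspec
    axes = np.empty((m, n), dtype=object)
    for i in range(m):
        for j in range(n):
            axes[i, j] = fig.add_subplot(inner[i, j])
    return axes


fig, (left, right) = plt.subplots(ncols=2)
grid = split_axes(right, 2, 2)
print(grid.shape)     # (2, 2)
print(len(fig.axes))  # 5: `left` plus the four new cells
```

The returned array can then be indexed as `grid[y, x]`, which is what the loop in the question expects.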
The following is an example. Note that it currently requires the Axes passed in to be a subplot (as opposed to an arbitrary Axes). It also hardcodes the number of plots to a 2 by 2 grid. In the real use case you would probably derive that number from the model you pass in.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec


class ClassName():
    def __init__(self, model, ax=None, **kwargs):
        ax = ax or plt.gca()
        # Only subplots know their slot in a gridspec: newer Matplotlib returns
        # None for other Axes, older versions lack get_subplotspec entirely.
        # (ax.get_geometry() was deprecated and later removed, so avoid it.)
        subplotspec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
        if subplotspec is None:
            raise ValueError("Axes needs to be a subplot")
        # Geometry of subplots
        m, n = 2, 2
        # New m x n gridspec occupying the slot of the passed-in Axes
        gs = gridspec.GridSpecFromSubplotSpec(m, n, subplot_spec=subplotspec)
        fig = ax.figure
        # The original Axes is only a placeholder; remove it
        ax.remove()
        self.axes = np.empty((m, n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i, j] = fig.add_subplot(gs[i, j], label=f"{i}{j}")

    def plot(self, data):
        # One row of `data` per cell
        for ax, d in zip(self.axes.flat, data):
            ax.plot(d)


_, (ax, ax2) = plt.subplots(ncols=2)
cm = ClassName("mymodel", ax=ax2)  # calling my class
cm.plot(np.random.rand(4, 10))
plt.show()
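To tie this back to histograms and to avoid the hardcoded 2 by 2 grid, here is a sketch that derives a near-square grid from the number of datasets and calls `Axes.hist` in each cell. `HistGrid` and `grid_shape` are hypothetical names, and the sketch takes a list of arrays instead of a model object; adapt the layout rule to whatever your model provides.

```python
import math

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def grid_shape(k):
    """Hypothetical layout rule: near-square (m, n) with m * n >= k."""
    n = math.ceil(math.sqrt(k))
    m = math.ceil(k / n)
    return m, n


class HistGrid:
    """Hypothetical variant of ClassName: one histogram per cell."""

    def __init__(self, datasets, ax=None, bins=10):
        ax = ax if ax is not None else plt.gca()
        spec = ax.get_subplotspec() if hasattr(ax, "get_subplotspec") else None
        if spec is None:
            raise ValueError("Axes needs to be a subplot")
        m, n = grid_shape(len(datasets))
        inner = spec.subgridspec(m, n)
        fig = ax.figure
        ax.remove()
        self.axes = np.empty((m, n), dtype=object)
        for i in range(m):
            for j in range(n):
                self.axes[i, j] = fig.add_subplot(inner[i, j])
        for cell, d in zip(self.axes.flat, datasets):
            cell.hist(d, bins=bins)
        # Hide cells left over when len(datasets) < m * n
        for cell in self.axes.flat[len(datasets):]:
            cell.set_visible(False)


rng = np.random.default_rng(0)
data = [rng.normal(size=200) for _ in range(5)]
_, ax = plt.subplots()  # the caller's (1, 1) Axes, unchanged
hg = HistGrid(data, ax=ax, bins=20)
print(hg.axes.shape)  # (2, 3)
```

Hiding the unused cells keeps the grid rectangular; calling `remove()` on them instead would also work if you prefer they not exist at all.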