
Plot subplots inside subplots in matplotlib

Context: I'd like to plot multiple subplots (separated by legend entry), based on patterns in the column names of a dataframe, inside a subplot. However, I'm not able to split each subplot into another set of subplots.

This is what I have:

import matplotlib.pyplot as plt
col_patterns = ['pattern1','pattern2']
# define subplot grid
fig, axs = plt.subplots(nrows=len(col_patterns), ncols=1, figsize=(30, 80))
plt.subplots_adjust()
fig.suptitle("Title", fontsize=18, y=0.95)
for col_pat,ax in zip(col_patterns,axs.ravel()):
    col_pat_columns = [col for col in df.columns if col_pat in col]

    df[col_pat_columns].plot(x='Week',ax=ax)
    # chart formatting
    ax.set_title(col_pat.upper())
    ax.set_xlabel("")

Which results in something like this:

[Image: the resulting subplots]

How could I make each of those subplots turn into another 6 subplots, all laid out horizontally? (i.e. each legend entry would get its own subplot)

Thank you!

In your example, you define a 2x1 grid of subplots and loop over only the two Axes objects it creates. In each iteration, when you call df[col_pat_columns].plot(x='Week',ax=ax), col_pat_columns is a list, so you're plotting several dataframe columns onto the same Axes. That's why you get multiple series on a single plot.
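A quick way to confirm this: selecting a list of columns gives a DataFrame, and DataFrame.plot(ax=ax) draws one line per column on that single Axes. A minimal sketch with made-up column names (hypothetical data, not the asker's):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical frame: 'Week' plus three columns sharing the substring 'pattern1'
df = pd.DataFrame({
    "Week": [1, 2, 3, 4],
    "pattern1_a": [1, 2, 3, 4],
    "pattern1_b": [4, 3, 2, 1],
    "pattern1_c": [2, 2, 2, 2],
})

cols = [c for c in df.columns if "pattern1" in c]
fig, ax = plt.subplots()
# x='Week' must be a column of the frame being plotted, so include it in the selection
df[["Week"] + cols].plot(x="Week", ax=ax)

# One Line2D per selected column, all on the same Axes
print(len(ax.get_lines()))  # prints 3
```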

@fdireito is correct: you just need to set the ncols argument of plt.subplots() to the number of subplots you need, and adjust your loops accordingly.
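If you'd rather keep using DataFrame.plot as in the question, pandas can split columns into separate Axes itself with subplots=True, and, as far as I know, it also accepts an array of existing Axes through ax as long as its length matches the number of plotted columns. A minimal sketch, assuming hypothetical column names that contain the pattern substrings and a 'Week' column for the x-axis:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical data standing in for the asker's frame
df = pd.DataFrame({
    "Week": [1, 2, 3, 4],
    "pattern1_a": [1, 2, 3, 4],
    "pattern1_b": [4, 3, 2, 1],
    "pattern1_c": [2, 2, 2, 2],
    "pattern2_a": [5, 1, 5, 1],
    "pattern2_b": [0, 1, 0, 1],
})
col_patterns = ["pattern1", "pattern2"]

# Columns matched by each pattern (one row of subplots per pattern)
groups = [[c for c in df.columns if p in c] for p in col_patterns]
ncols = max(len(g) for g in groups)

# squeeze=False keeps `axs` 2-D even when there is only one row or column
fig, axs = plt.subplots(len(groups), ncols,
                        figsize=(4 * ncols, 3 * len(groups)), squeeze=False)

for row, (pat, cols) in enumerate(zip(col_patterns, groups)):
    # subplots=True gives each column its own Axes; the number of Axes passed
    # must equal the number of plotted columns, hence the slice
    df[["Week"] + cols].plot(x="Week", subplots=True, ax=axs[row, :len(cols)])
    axs[row, 0].set_ylabel(pat.upper())
    # Hide leftover Axes when a pattern matches fewer columns than the widest row
    for ax in axs[row, len(cols):]:
        ax.set_axis_off()
```

Each subplot gets its own legend by default, which matches the "one legend entry per subplot" goal.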

If you want to stay in matplotlib, then here's a basic example. I had to take some guesses as to how your dataframe was structured and so on.

# import matplotlib and pandas
import matplotlib.pyplot as plt
import pandas as pd

# create some fake data
x = [1, 2, 3, 4, 5]

df = pd.DataFrame({
    'a':[1, 1, 1, 1, 1],    # horizontal line
    'b':[3, 6, 9, 6, 3],    # pyramid
    'c':[4, 8, 12, 16, 20], # steep line
    'd':[1, 10, 3, 13, 5]   # zig-zag
})

# a list of lists, where each inner list is a set of
# columns we want in the same row of subplots
col_patterns = [['a', 'b', 'c'], ['b', 'c', 'd']]
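Here col_patterns is a list of lists of column names. If, as in the question, you start from substring patterns, one way to build that structure is a nested comprehension. The column names below are hypothetical:

```python
import pandas as pd

# Hypothetical columns; only the names matter here
df = pd.DataFrame(columns=["Week", "pattern1_a", "pattern1_b",
                           "pattern2_a", "pattern2_b", "pattern2_c"])
patterns = ["pattern1", "pattern2"]

# One inner list per pattern: the columns whose names contain that substring
col_patterns = [[c for c in df.columns if p in c] for p in patterns]
print(col_patterns)
# [['pattern1_a', 'pattern1_b'], ['pattern2_a', 'pattern2_b', 'pattern2_c']]
```

Note that a plain substring test is loose: 'pattern1' would also match a column such as 'pattern10_a', so a prefix check or a regex may be safer for real column names.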

The following is a simplified example of what your code ends up doing.

fig, axes = plt.subplots(len(col_patterns), 1)

for pat, ax in zip(col_patterns, axes):
    ax.plot(x, df[pat])

2x1 subplot (what you have right now)

I use enumerate() with col_patterns to iterate through the subplot rows, and then use enumerate() with each column name in a given pattern to iterate through the subplot columns.

# the following will size your subplots according to
# - number of different column patterns you want matched (rows)
# - largest number of columns in a given column pattern (columns)
subplot_rows = len(col_patterns)
subplot_cols = max(len(p) for p in col_patterns)
fig, axes = plt.subplots(subplot_rows, subplot_cols)

for nrow, pat in enumerate(col_patterns):
    for ncol, col in enumerate(pat):
        axes[nrow][ncol].plot(x, df[col])

Correctly sized subplot
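One caveat with indexing axes[nrow][ncol]: plt.subplots() squeezes out length-1 dimensions by default, so with a single pattern axes is 1-D and axes[0][ncol] fails. Passing squeeze=False always returns a 2-D array. When patterns have different lengths, the leftover Axes at the end of shorter rows can be hidden. A sketch reusing the same fake data, with an uneven second pattern chosen for illustration:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd

x = [1, 2, 3, 4, 5]
df = pd.DataFrame({
    'a': [1, 1, 1, 1, 1],
    'b': [3, 6, 9, 6, 3],
    'c': [4, 8, 12, 16, 20],
    'd': [1, 10, 3, 13, 5],
})
# Uneven rows: the second pattern has only two columns
col_patterns = [['a', 'b', 'c'], ['b', 'd']]

subplot_rows = len(col_patterns)
subplot_cols = max(len(p) for p in col_patterns)
# squeeze=False: `axes` is always 2-D, shape (rows, cols)
fig, axes = plt.subplots(subplot_rows, subplot_cols, squeeze=False)

for nrow, pat in enumerate(col_patterns):
    for ncol, col in enumerate(pat):
        axes[nrow, ncol].plot(x, df[col], label=col)
        axes[nrow, ncol].legend(loc='upper left')
    # Hide the unused Axes at the end of shorter rows
    for ax in axes[nrow, len(pat):]:
        ax.set_axis_off()
```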

Here's all the code, with a couple of additions (figure sizes, shared axes, legends) that I omitted above for simplicity's sake.

import matplotlib.pyplot as plt
import pandas as pd

x = [1, 2, 3, 4, 5]

df = pd.DataFrame({
    'a':[1, 1, 1, 1, 1],    # horizontal line
    'b':[3, 6, 9, 6, 3],    # pyramid
    'c':[4, 8, 12, 16, 20], # steep line
    'd':[1, 10, 3, 13, 5]   # zig-zag
})

col_patterns = [['a', 'b', 'c'], ['b', 'c', 'd']]

# what you have now
fig, axes = plt.subplots(len(col_patterns), 1, figsize=(12, 8))

for pat, ax in zip(col_patterns, axes):
    ax.plot(x, df[pat])
    ax.legend(pat, loc='upper left')

# what I think you want
subplot_rows = len(col_patterns)
subplot_cols = max(len(p) for p in col_patterns)

fig, axes = plt.subplots(subplot_rows, subplot_cols, figsize=(16, 8), sharex=True, sharey=True, tight_layout=True)

for nrow, pat in enumerate(col_patterns):
    for ncol, col in enumerate(pat):
        axes[nrow][ncol].plot(x, df[col], label=col)
        axes[nrow][ncol].legend(loc='upper left')
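Since the question literally asks for subplots inside subplots, matplotlib's Figure.subfigures() (added in matplotlib 3.4) is another option: each pattern becomes a SubFigure with its own title and its own horizontal row of Axes. A sketch reusing the fake data, assuming matplotlib 3.5 or newer for layout='constrained'; the row titles are made up:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd

x = [1, 2, 3, 4, 5]
df = pd.DataFrame({
    'a': [1, 1, 1, 1, 1],
    'b': [3, 6, 9, 6, 3],
    'c': [4, 8, 12, 16, 20],
    'd': [1, 10, 3, 13, 5],
})
col_patterns = [['a', 'b', 'c'], ['b', 'c', 'd']]

fig = plt.figure(figsize=(16, 8), layout='constrained')
fig.suptitle('Title')
# One SubFigure per pattern; squeeze=False keeps a 2-D array even for one row
subfigs = fig.subfigures(len(col_patterns), 1, squeeze=False)[:, 0]

row_axes = []
for nrow, (subfig, pat) in enumerate(zip(subfigs, col_patterns)):
    subfig.suptitle(f'Pattern {nrow + 1}')  # hypothetical per-row title
    # Each SubFigure gets its own row of Axes, one per column in the pattern
    axes = subfig.subplots(1, len(pat), sharey=True, squeeze=False)[0]
    for ax, col in zip(axes, pat):
        ax.plot(x, df[col], label=col)
        ax.legend(loc='upper left')
    row_axes.append(axes)
```

Because each row lives in its own SubFigure, rows can have different numbers of Axes without leaving empty panels.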

Another option is to use Seaborn's relplot() (which is built on top of matplotlib) instead of laying out the grid yourself. There are several examples on its documentation page that should help. If your dataframe is in long ("tidy") format, then achieving the same as above is roughly a one-liner:

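import seaborn as sns

# x_vals, y_vals, col_pattern and num_weeks_rolling are placeholders
# for the names of columns in your long-format dataframe
sns.relplot(data=df, kind='line', x=x_vals, y=y_vals, row=col_pattern, col=num_weeks_rolling)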

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this content, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.
