I'm using seaborn to plot the results of different algorithms. I want to distinguish both the different algorithms and their classification ("group"). The problem is that not all algorithms are in all groups, so when I use group as x and alg as hue, I get a lot of blank space:
import seaborn as sns
group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
results = [i+1 for i in range(len(group))]
sns.barplot(x=group, y=results, hue=alg)
As you can see, seaborn reserves space for bars from every algorithm in every group, which leads to lots of blank space. How can I avoid that? I do want to show the different groups on the x-axis and distinguish the different algorithms by color/style. Algorithms may occur in multiple, but not all, groups. I just want space for 2 bars in "Simple" and "Complex" and for only 1 in "Cool". Solutions with pure matplotlib are also welcome; it doesn't need to be seaborn. I'd like to keep the seaborn color palette, though.
There doesn't seem to be a standard way to create this type of grouped bar plot. The following code creates a list of positions for the bars, their colors, and lists for the labels and their positions. It assumes that entries belonging to the same group are adjacent in the input lists.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Patch
group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
colors = plt.cm.tab10.colors
alg_cat = pd.Categorical(alg)
alg_colors = [colors[c] for c in alg_cat.codes]
results = [i + 1 for i in range(len(group))]
dist_groups = 0.4 # distance between successive groups
pos = (np.array([0] + [g1 != g2 for g1, g2 in zip(group[:-1], group[1:])]) * dist_groups + 1).cumsum()
labels = [g1 for g1, g2 in zip(group[:-1], group[1:]) if g1 != g2] + group[-1:]
label_pos = [sum([p for g, p in zip(group, pos) if g == label]) / len([1 for g in group if g == label])
             for label in labels]
plt.bar(pos, results, color=alg_colors)
plt.xticks(label_pos, labels)
handles = [Patch(color=colors[c], label=lab) for c, lab in enumerate(alg_cat.categories)]
plt.legend(handles=handles)
plt.show()
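The position arithmetic above can be isolated in a small function, which makes it easier to check or reuse. This is a minimal sketch, not part of the original answer: the name `group_bar_positions` is hypothetical, and like the code above it assumes that all entries of one group are adjacent in `group`.

```python
import numpy as np


def group_bar_positions(group, dist_groups=0.4):
    """Return bar positions, group labels and tick positions.

    Hypothetical helper mirroring the inline code above; assumes that all
    entries of one group are adjacent in `group`.
    """
    # 1 for every bar that starts a new group, 0 otherwise (first bar: 0)
    new_group = np.array([0] + [g1 != g2 for g1, g2 in zip(group[:-1], group[1:])])
    # each bar advances by 1; starting a new group adds an extra gap
    pos = (new_group * dist_groups + 1).cumsum()
    # one label per run of identical groups, in order of appearance
    labels = [g1 for g1, g2 in zip(group[:-1], group[1:]) if g1 != g2] + list(group[-1:])
    # center each tick label under the bars of its group
    label_pos = [pos[np.array([g == label for g in group])].mean() for label in labels]
    return pos, labels, label_pos
```

For the question's data it returns bar positions of about 1, 2, 3.4, 4.4 and 5.8 and tick positions 1.5, 3.9 and 5.8, which can be passed to `plt.bar` and `plt.xticks` as above.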
While one could handle this case entirely within matplotlib and numpy, I've solved it with pandas. The reason is that you need a way to do the categorical groupings correctly, and that is one of the main strengths of pandas.
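To make the grouping point concrete: `DataFrame.groupby` sorts its keys by default, so plain strings come out alphabetically, while a `pd.Categorical` with an explicit category list keeps that order. The sketch below uses the question's data; passing `observed=True` is an added choice (an assumption, not from the original answer) that restricts the result to categories actually present. It also counts how many bars each group needs.

```python
import pandas as pd

group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
df = pd.DataFrame({'group': group, 'alg': alg,
                   'results': list(range(1, len(group) + 1))})

# plain strings: groupby sorts the keys alphabetically
plain_order = [key for key, _ in df.groupby('group')]

# categorical with an explicit order: groupby follows the category order
df['group'] = pd.Categorical(df['group'], ['Simple', 'Complex', 'Cool'])
cat_order = [key for key, _ in df.groupby('group', observed=True)]

# number of bar slots each group actually needs
bars_per_group = df.groupby('group', observed=True)['alg'].nunique().to_dict()
```

Here `plain_order` is `['Complex', 'Cool', 'Simple']`, `cat_order` is `['Simple', 'Complex', 'Cool']`, and the groups need 2, 2 and 1 bars respectively.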
So what I did, essentially, is create a DataFrame from your data and group it by the group column. While iterating over each enumerated group with index i = 0, 1, 2, ..., we create a set of ax.bar() plots, each confined to the interval [i-0.5, i+0.5]. The colours are taken from a seaborn color palette, as requested, and are also used at the end to create a custom legend.
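The placement rule can be checked in isolation. The following sketch reproduces the `np.linspace` arithmetic from the loop below in a standalone function; the name `slot_offsets` is hypothetical. With `n` bars of width `1/n` centered on the returned values (matplotlib's default `align='center'`), the bars exactly fill the interval [i-0.5, i+0.5].

```python
import numpy as np


def slot_offsets(i, n):
    """Bar centers and width for n bars sharing the unit slot around x = i."""
    rel = np.linspace(0, 1, n + 1)[:-1]  # n evenly spaced values: 0, 1/n, ..., (n-1)/n
    rel -= rel.mean()                    # shift so the values are symmetric around 0
    rel += i                             # move the group to its slot at x = i
    w = 1 / n                            # n bars of width 1/n fill the slot exactly
    return rel, w
```

For example, `slot_offsets(0, 2)` gives centers -0.25 and 0.25 with width 0.5, and `slot_offsets(2, 1)` gives a single full-width bar at 2.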
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd
import numpy as np
group = ['Simple', 'Simple', 'Complex', 'Complex', 'Cool']
alg = ['Alg 1', 'Alg 2', 'Alg 3', 'Alg 4', 'Alg 2']
results = [i+1 for i in range(len(group))]
df = pd.DataFrame({'group':group, 'alg':alg, 'results':results})
## this next line is only required if you have a specific order in mind;
# else, the .groupby() method will sort alphabetically!
df['group'] = pd.Categorical(df['group'], ["Simple", "Complex", "Cool"])
## choose the desired seaborn color palette here:
palette = sns.color_palette("Paired", df['alg'].nunique())
labels, levels = pd.factorize(df['alg'])
df['color'] = [palette[l] for l in labels]
gdf = df.groupby('group', observed=True)
fig,ax=plt.subplots(figsize=(5,3))
xt = []
xtl = []
min_width = 1/max([len(item) for (key,item) in gdf])
for i,(key,item) in enumerate(gdf):
    xt.append(i)
    xtl.append(key)
    ## for each enumerated group, we need to create the proper x-scale placement
    # such that each bar plot is centered around " i "
    # i.e. two bars near " i = 0 " will be placed at [-0.25, 0.25] with widths of 0.5
    # so that they won't collide with bars near " i = 1 "
    # which themselves are placed at [0.75, 1.25]
    rel = np.linspace(0,1,len(item)+1)[:-1]
    rel -= rel.mean()
    rel += i
    w = 1/(len(item))
    ## note that the entire interval width (i.e. from i-0.5 to i+0.5) will be filled with the bars,
    # meaning that the individual bar widths will vary depending on the number of bars.
    # either adjust the bar width like this to add some whitespace:
    # w *= 0.9
    ## or alternatively, you could use a fixed width instead:
    # w = 0.4
    ## or, by pre-evaluating the minimal required bar width:
    # w = min_width
    ax.bar(rel,item['results'].values,alpha=1,width=w,color=item['color'])
leg = []
for i,l in enumerate(levels):
    p = mpatches.Patch(color=palette[i], label=l)
    leg.append(p)
ax.legend(handles=leg)
ax.set_xticks(xt)
ax.set_xticklabels(xtl)
ax.grid()
plt.show()
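If all bars should have the same width (the `w = min_width` option in the comments), spreading their centers over the whole slot can leave gaps between the bars of groups that are smaller than the largest one. One alternative, sketched here as an assumption rather than part of the original answer, is to space the centers by the bar width so each group's bars stay adjacent and centered on their tick. `packed_positions` is a hypothetical helper; its `counts` would be `[len(item) for key, item in gdf]`, and the bars must then be drawn in the same order that `gdf` iterates over.

```python
import numpy as np


def packed_positions(counts, fill=0.9):
    """Bar centers for groups of sizes `counts`, all bars sharing one width.

    Group k is centered on x = k. The width is chosen so that the largest
    group fills `fill` of its unit slot, so neighbouring groups never touch
    as long as fill < 1.
    """
    width = fill / max(counts)
    centers = []
    for k, n in enumerate(counts):
        # offsets -(n-1)/2 ... (n-1)/2 in units of the bar width keep the
        # bars of one group adjacent and centered on the tick at x = k
        centers.extend(k + (np.arange(n) - (n - 1) / 2) * width)
    return np.array(centers), width
```

With the question's groups (`counts = [2, 2, 1]`) and `fill=1.0`, this reproduces the centers -0.25, 0.25, 0.75, 1.25 and 2.0 with width 0.5; the default `fill=0.9` narrows every bar to 0.45 and leaves a small gap between groups.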
The result (using sns.color_palette("Paired") and w=1/len(item)) shows two bars each for "Simple" and "Complex" and a single bar for "Cool", with no empty slots for missing algorithms.