Use the same color mapping for table rows and a multi-distplot figure
I have a dataframe from which I take several groups of data and display them as overlaid distplots on the same figure. I also display a table summarizing some data about each group.
I would like to display each row in the table (= each group) in the same color as the matching distplot curve.
I tried to define a common colormap for both the table and the distplot, but distplot throws an error:
in distplot
if kde_color != color:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Process finished with exit code 1
Here is the code:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# item_df is the dataframe described above
fig, (ax_plot, ax_table) = plt.subplots(nrows=2, figsize=(11.69, 8.27),
                                        gridspec_kw=dict(height_ratios=[3, 1]))
ax_table.axis("off")
item_types = item_df['item_type'].unique()
columns = ('item type', 'Average DR', 'Percent DR passed 50%', 'Percent DR passed 60%', 'Percent DR passed 70%',
           'Percent DR passed 80%')
cell_text = []
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, len(item_types)))
i = 0
for item_type in item_types:
    item_dr = item_df[item_df['item_type'] == item_type]['interesting_feature'].values
    color = table_colors[i, 0:3]
    # this call raises the ValueError shown above
    sns.distplot(item_dr, hist=False, label=item_type, ax=ax_plot, color=mcolors.rgb_to_hsv(color))
    i += 1
    avg_dr = np.mean(item_dr)
    pass50 = len(item_dr[item_dr > 0.5]) / len(item_dr)
    pass60 = len(item_dr[item_dr > 0.6]) / len(item_dr)
    pass70 = len(item_dr[item_dr > 0.7]) / len(item_dr)
    pass80 = len(item_dr[item_dr > 0.8]) / len(item_dr)
    cell_text.append([str(item_type), str(avg_dr), str(pass50), str(pass60), str(pass70), str(pass80)])
item_table = ax_table.table(cellText=cell_text,
                            colLabels=columns,
                            loc='center',
                            fontsize=20,
                            rowColours=table_colors)
First off, converting to HSV with mcolors.rgb_to_hsv(color) doesn't look very useful, since the color argument is interpreted as RGB.
Now, the main problem seems to be that passing a color as a list or a numpy array (e.g. [1, 0, 0]) confuses sns.distplot(..., color=color). Many seaborn functions accept either a single color or a list of colors, and don't distinguish between a single color given as RGB values and an array of colors. The workaround is to convert the color to a tuple: sns.distplot(..., color=tuple(color)).
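To see why the tuple matters, here is a small NumPy-only sketch of the comparison named in the traceback (`if kde_color != color:`). The helper `is_ambiguous` and the sample RGB values are made up for illustration; the sketch only mimics that one line and is not seaborn's actual code.

```python
import numpy as np

# An RGB row as it would come out of a colormap array (values are illustrative).
color = np.array([0.97, 0.99, 0.99])
kde_color = color.copy()

# Comparing two arrays is element-wise: the result is an array, not a single bool.
print(kde_color != color)  # [False False False]


def is_ambiguous(a, b):
    """Hypothetical helper: report whether `if a != b:` raises ValueError."""
    try:
        if a != b:
            pass
        return False
    except ValueError:
        return True


array_fails = is_ambiguous(kde_color, color)                # arrays: ambiguous truth value
tuple_fails = is_ambiguous(tuple(kde_color), tuple(color))  # tuples: compare to one bool
print(array_fails, tuple_fails)  # True False
```

Tuples compare as whole objects and yield a single boolean, which is why `color=tuple(color)` avoids the error.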
Here is a minimal example:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

num_colors = 5
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))
fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
for i in range(num_colors):
    color = table_colors[i, 0:3]
    # sns.distplot(np.random.normal(0, 1, 100), hist=False, color=color) # gives an error
    sns.distplot(np.random.normal(0, 1, 100), hist=False, color=tuple(color), ax=ax_plot)
columns = list('abcdef')
num_columns = len(columns)
ax_table.table(cellText=np.random.randint(1, 1000, size=(num_colors, num_columns)) / 100,
               colLabels=columns, loc='center', fontsize=20,
               cellColours=np.repeat(table_colors, num_columns, axis=0).reshape(num_colors, num_columns, -1))
ax_table.axis('off')
plt.tight_layout()
plt.show()
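A possible variation, offered as an assumption and not part of the original answer: convert each colormap row to a hex string with `matplotlib.colors.to_hex`. A string cannot be mistaken for a list of colors, so the same value could be passed to `sns.distplot(..., color=hex_color)` and to the table. To stay matplotlib-only, the sketch below draws step histograms as a stand-in for the density curves; the names `hex_colors`, `cell_text` and `cell_colours` are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

num_colors = 5
columns = list("abcdef")

# One RGBA row per group, converted to '#rrggbb' strings.
rgba = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))
hex_colors = [mcolors.to_hex(c) for c in rgba]

fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
rng = np.random.default_rng(0)
for hex_color in hex_colors:
    data = rng.normal(0, 1, 100)
    # Stand-in for the density curve; the same string works as a seaborn color.
    ax_plot.hist(data, bins=20, density=True, histtype="step", color=hex_color)

# Random placeholder numbers, formatted as text for the table cells.
cell_text = [[f"{v:.2f}" for v in rng.uniform(0, 10, len(columns))]
             for _ in hex_colors]
# Repeat each group's color across all cells of its row.
cell_colours = [[c] * len(columns) for c in hex_colors]
table = ax_table.table(cellText=cell_text, colLabels=columns,
                       cellColours=cell_colours, loc="center")
ax_table.axis("off")
fig.tight_layout()
```

Hex strings drop the alpha channel by default, which is fine here because the colormap rows are fully opaque.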
To change the color of the text instead, you can loop through the cells of the table. As these particular colors are not very visible on a white background, the cell background could be set to black.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

num_colors = 5
table_colors = plt.cm.BuPu(np.linspace(0, 0.5, num_colors))
fig, (ax_plot, ax_table) = plt.subplots(nrows=2)
for i in range(num_colors):
    color = table_colors[i, :]
    # sns.distplot(np.random.normal(0, 1, 100), hist=False, color=color) # gives an error
    sns.distplot(np.random.normal(0, 1, 100), hist=False, color=tuple(color), ax=ax_plot)
columns = list('abcdef')
num_columns = len(columns)
table = ax_table.table(cellText=np.random.randint(1, 1000, size=(num_colors, num_columns)) / 100,
                       colLabels=columns, loc='center', fontsize=20)
for i in range(num_colors):
    for j in range(num_columns):
        table[(i+1, j)].set_color('black')  # +1: skip the table header
        table[(i+1, j)].get_text().set_color(table_colors[i, :])
ax_table.axis('off')
plt.tight_layout()
plt.show()
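If the colormap changes, a fixed black background may no longer suit every row. One option, sketched here as an assumption that goes beyond the answer above, is to pick black or white per row from the text color's brightness. The helper `contrasting_background`, the 0.5 threshold and the Rec. 601 luma weights are illustrative choices, not a matplotlib or seaborn API.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def contrasting_background(rgb, threshold=0.5):
    """Hypothetical helper: 'black' for light text colors, 'white' for dark ones."""
    r, g, b = rgb[:3]
    # Approximate perceived brightness with the Rec. 601 luma weights.
    luma = 0.299 * r + 0.587 * g + 0.114 * b
    return "black" if luma > threshold else "white"


table_colors = plt.cm.BuPu(np.linspace(0, 0.5, 5))
backgrounds = [contrasting_background(c) for c in table_colors]
print(backgrounds)
```

Inside the cell loop, `table[(i+1, j)].set_color(backgrounds[i])` would then replace the hard-coded `'black'`.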