
How to annotate swarmplot points on a categorical axis with labels from a different column

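I'm trying to add labels to some of the values in my matplotlib/seaborn plot. I don't want to label all of them, only those above a certain value. In the example below, which uses the iris dataset from sklearn, I want to label the values greater than 3.6 on the x-axis.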

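Here, @Scinana asked last year how to do this when both axes are numeric. That question has an accepted answer, but I haven't been able to adapt it to my situation. The link provided in the accepted answer didn't help either.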

The code below works until the last step (the labelling), where it throws: TypeError: 'FacetGrid' object is not callable

In addition, each outlier needs to be annotated with its corresponding value from dfiris['sepal length (cm)'], not just with the string 'outlier'.

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris

dfiris = load_iris()
dfiris = pd.DataFrame(data=dfiris.data, columns=dfiris.feature_names)
dfiris['name'] = np.where(dfiris['sepal width (cm)'] < 3, 'Amy', 'Bruce')  # adding a fake categorical variable 
dfiris['name'] = np.where((dfiris.name != 'Amy') & (dfiris['petal length (cm)'] >= 3.4), 'Charles', dfiris.name) # adding to that fake categorical variable 

a_viz = sns.catplot(x='sepal width (cm)', y= 'name', kind = 'swarm', data=dfiris)
a_viz.fig.set_size_inches(5, 6)
a_viz.fig.subplots_adjust(top=0.81, right=0.86)

for x, y in zip(dfiris['sepal width (cm)'], dfiris['name']):
    if x > 3.6:
        # this is the step that fails: a_viz is a seaborn FacetGrid, not a matplotlib Axes
        a_viz.text(x, y, 'outlier', horizontalalignment='left', size='medium', color='black')
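Below is a minimal sketch of a direct fix for the question's code. It assumes that seaborn draws the i-th category of `order` at position i on the categorical axis.

- The `text` call goes to the matplotlib Axes inside the FacetGrid. For a grid with no faceting variables, that Axes is `a_viz.ax`.
- Each label is taken from 'sepal length (cm)'.
- The explicit `order` list and the `y_pos` mapping are illustrative additions that are not part of the original code.

Labels for points that share the same x value within a category will still overlap.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris

iris = load_iris()
dfiris = pd.DataFrame(data=iris.data, columns=iris.feature_names)
dfiris['name'] = np.where(dfiris['sepal width (cm)'] < 3, 'Amy', 'Bruce')
dfiris['name'] = np.where((dfiris.name != 'Amy') & (dfiris['petal length (cm)'] >= 3.4), 'Charles', dfiris.name)

# assumption: with an explicit order, category i is drawn at y == i
order = ['Amy', 'Bruce', 'Charles']
a_viz = sns.catplot(x='sepal width (cm)', y='name', kind='swarm', data=dfiris, order=order)

ax = a_viz.ax  # the single Axes of a one-facet FacetGrid
y_pos = {name: i for i, name in enumerate(order)}  # hypothetical helper mapping

for x, name, length in zip(dfiris['sepal width (cm)'], dfiris['name'], dfiris['sepal length (cm)']):
    if x > 3.6:
        # label with the value from another column instead of a fixed string
        ax.text(x, y_pos[name], str(length), ha='left', va='center', size='medium', color='black')
```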

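The existing duplicates don't fully address how to add annotations from a different column, nor how to keep the annotations from overlapping.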

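  • With a swarmplot, the individual position of each observation on the categorical axis is not available. Every point in a category is annotated at the same tick location, so annotations for points with the same x-axis value will overlap.
    • This can be solved by using pandas.DataFrame.groupby to combine the labels into a single string per position, which is then passed to s=.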
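The groupby step can be seen in isolation on a few hypothetical rows. Values that share a (sepal_width, species) position are joined with newlines into one string, so only one annotation is drawn per position.

```python
import pandas as pd

# hypothetical outlier rows: two of them share the position (3.8, 'setosa')
outliers = pd.DataFrame({
    'sepal_width': [3.8, 3.8, 3.9, 3.8],
    'species': ['setosa', 'setosa', 'setosa', 'virginica'],
    'sepal_length': [5.7, 5.1, 5.4, 7.7],
})

labels = (
    outliers.assign(sepal_length=outliers['sepal_length'].astype(str))  # join needs strings
    .groupby(['sepal_width', 'species'])['sepal_length']
    .agg('\n'.join)  # one newline-separated string per (x, category) position
    .reset_index()
)
print(labels)
```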

Non-overlapping annotations

import seaborn as sns

# load sample data that has text labels
df = sns.load_dataset('iris')

# plot the DataFrame
g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, height=7, aspect=2)

# there is only one axes for this plot; provide an alias for ease of use
ax = g.axes[0, 0]

# get the ytick locations for each name
ytick_loc = {v.get_text(): v.get_position()[1] for v in ax.get_yticklabels()}

# add the ytick locations for each observation
df['ytick_loc'] = df.species.map(ytick_loc)

# filter the dataframe to only contain the outliers
outliers = df[df.sepal_width.gt(3.6)].copy()

# convert the column to strings for annotations
outliers['sepal_length'] = outliers['sepal_length'].astype(str)

# combine all the sepal_length values as a single string for each species and width
labels = outliers.groupby(['sepal_width', 'ytick_loc']).agg({'sepal_length': '\n'.join}).reset_index()

# iterate through each axes of the FacetGrid with `for ax in g.axes.flat:` or specify the exact axes to use
for _, (x, y, s) in labels.iterrows():
    ax.text(x + 0.01, y, s=s, horizontalalignment='left', size='medium', color='black', verticalalignment='center', linespacing=1)
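Below is a variant of the same approach. It is a sketch that makes two changes, both assumptions rather than part of the original answer:

- The category order is passed explicitly, so each species' y position is computed as its index in `order` instead of being read back from the tick labels.
- The label offset is given in points with `ax.annotate(..., textcoords='offset points')` instead of `x + 0.01` in data units, so the gap stays the same if the x range changes.

The small DataFrame is made-up stand-in data, so the sketch runs without downloading the seaborn iris dataset.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import pandas as pd
import seaborn as sns

# hypothetical stand-in for the iris columns used in the answer
df = pd.DataFrame({
    'sepal_width': [3.0, 3.7, 3.7, 3.9, 2.8, 3.8, 3.8],
    'species': ['setosa'] * 4 + ['virginica'] * 3,
    'sepal_length': [4.9, 5.4, 5.1, 5.4, 6.3, 7.7, 7.9],
})

order = ['setosa', 'virginica']  # assumption: category i is drawn at y == i
g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, order=order, height=4, aspect=2)
ax = g.ax

outliers = df[df['sepal_width'].gt(3.6)].copy()
outliers['y'] = outliers['species'].map({name: i for i, name in enumerate(order)})
outliers['sepal_length'] = outliers['sepal_length'].astype(str)

# one newline-joined label per (x, y) position, as in the answer
labels = outliers.groupby(['sepal_width', 'y'])['sepal_length'].agg('\n'.join).reset_index()

for x, y, s in labels.itertuples(index=False):
    # offset the text 6 points to the right of the data point
    ax.annotate(s, xy=(x, y), xytext=(6, 0), textcoords='offset points',
                ha='left', va='center', linespacing=1)
```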



DataFrame views

df

   sepal_length  sepal_width  petal_length  petal_width species  ytick_loc
0           5.1          3.5           1.4          0.2  setosa          0
1           4.9          3.0           1.4          0.2  setosa          0
2           4.7          3.2           1.3          0.2  setosa          0
3           4.6          3.1           1.5          0.2  setosa          0
4           5.0          3.6           1.4          0.2  setosa          0

outliers

    sepal_length  sepal_width  petal_length  petal_width    species  ytick_loc
5            5.4          3.9           1.7          0.4     setosa          0
10           5.4          3.7           1.5          0.2     setosa          0
14           5.8          4.0           1.2          0.2     setosa          0
15           5.7          4.4           1.5          0.4     setosa          0
16           5.4          3.9           1.3          0.4     setosa          0
18           5.7          3.8           1.7          0.3     setosa          0
19           5.1          3.8           1.5          0.3     setosa          0
21           5.1          3.7           1.5          0.4     setosa          0
32           5.2          4.1           1.5          0.1     setosa          0
33           5.5          4.2           1.4          0.2     setosa          0
44           5.1          3.8           1.9          0.4     setosa          0
46           5.1          3.8           1.6          0.2     setosa          0
48           5.3          3.7           1.5          0.2     setosa          0
117          7.7          3.8           6.7          2.2  virginica          2
131          7.9          3.8           6.4          2.0  virginica          2

labels

   sepal_width  ytick_loc        sepal_length
0          3.7          0       5.4\n5.1\n5.3
1          3.8          0  5.7\n5.1\n5.1\n5.1
2          3.8          2            7.7\n7.9
3          3.9          0            5.4\n5.4
4          4.0          0                 5.8
5          4.1          0                 5.2
6          4.2          0                 5.5
7          4.4          0                 5.7

Overlapping annotations

import seaborn as sns

# load sample data that has text labels
df = sns.load_dataset('iris')

# plot the DataFrame
g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, height=7, aspect=2)

# there is only one axes for this plot; provide an alias for ease of use
ax = g.axes[0, 0]

# get the ytick locations for each name
ytick_loc = {v.get_text(): v.get_position()[1] for v in ax.get_yticklabels()}

# plot the text annotations
for x, y, s in zip(df.sepal_width, df.species.map(ytick_loc), df.sepal_length):
    if x > 3.6:
        ax.text(x, y, s, horizontalalignment='left', size='medium', color='k')
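You may prefer one label per point, as in the overlapping version, but still want the labels to be readable. One possible middle ground is to stack labels that share a position. This is sketched here as an assumption, not as part of the original answer:

- `groupby(...).cumcount()` numbers the rows within each (x, category) group, starting at 0.
- Each label is shifted down by a fixed 12 points per rank.

The DataFrame is hypothetical stand-in data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import pandas as pd
import seaborn as sns

# hypothetical stand-in data with repeated x values in the same category
df = pd.DataFrame({
    'sepal_width': [3.0, 3.7, 3.7, 3.9, 2.8, 3.8, 3.8],
    'species': ['setosa'] * 4 + ['virginica'] * 3,
    'sepal_length': [4.9, 5.4, 5.1, 5.4, 6.3, 7.7, 7.9],
})

order = ['setosa', 'virginica']  # assumption: category i is drawn at y == i
g = sns.catplot(x='sepal_width', y='species', kind='swarm', data=df, order=order)
ax = g.ax

outliers = df[df['sepal_width'] > 3.6].copy()
outliers['y'] = outliers['species'].map({name: i for i, name in enumerate(order)})
# 0 for the first label at a position, 1 for the second, and so on
outliers['stack'] = outliers.groupby(['sepal_width', 'y']).cumcount()

for row in outliers.itertuples(index=False):
    # each additional label at the same position moves 12 points further down
    ax.annotate(str(row.sepal_length), xy=(row.sepal_width, row.y),
                xytext=(6, -12 * row.stack), textcoords='offset points',
                ha='left', va='center', size='medium', color='k')
```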


