繁体   English   中英

如何使用 plot 在具有特定 colors 和每条线的线型的特定点上使用 seaborn 绘制线图?

[英]How to plot a lineplot with dots on specific points with specific colors and linetypes for each line using seaborn?

我有以下 dataframe

import pandas as pd



 data_tmp = pd.DataFrame({'x': [0,14,28,42,56, 0,14,28,42,56],
                         'y': [0, 0.003, 0.006, 0.008, 0.001, 0*2, 0.003*2, 0.006*2, 0.008*2, 0.001*2],
                         'cat': ['A','A','A','A','A','B','B','B','B','B'],
                         'color': ['#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0','#247AB2','#247AB2','#247AB2','#247AB2','#247AB2'],
                         'point': [14,14,14,14,14,28,28,28,28,28],
                         'linestyles':['-','-','-','-','-','--','--','--','--','--']})

我想为每只cat制作一个具有不同colorlinestyles样式的线图。 但我想为每cat提供特定的colordataframe linestyles定义。 最后,我想用相同的颜色标记每条线上的point s。

我只试过:

sns.lineplot(x="x", y="y", hue="cat", data=data_tmp)
sns.scatterplot(x="point",y="y",hue="cat", data=data_tmp[data_tmp.point==data_tmp.x])
plt.show()

有任何想法吗?

也许你想直接使用 matplotlib ,比如

import pandas as pd
import matplotlib.pyplot as plt


df = pd.DataFrame({'x': [0,14,28,42,56, 0,14,28,42,56],
                   'y': [0, 0.003, 0.006, 0.008, 0.001, 0*2, 0.003*2, 0.006*2, 0.008*2, 0.001*2],
                   'cat': ['A','A','A','A','A','B','B','B','B','B'],})


d = {"A" : {"color": '#B5D8F0', "markersize":  5, "linestyle": "-"},
     "B" : {"color": '#247AB2', "markersize": 10, "linestyle": "--"}}

for n, grp in df.groupby("cat"):
    plt.plot(grp.x, grp.y, marker="o", label=n, **d[n])

plt.legend()
plt.show()

在此处输入图像描述

这就是我能做到的。 您需要使用cat列来控制不同的 plot 参数(颜色、样式、标记大小),然后创建映射对象(此处为 dicts)来告诉每个类别使用哪个参数值。 颜色很简单。 线条样式更难,因为 Seaborn 仅提供dashes作为可配置参数,需要在(segment, gap)的高级 Matplotlib 格式中给出。 function matplotlib.lines._get_dash_pattern将字符串值(例如-- )转换为这种格式,尽管返回值需要小心处理。 对于标记大小,不幸的lineplot不提供随类别更改标记大小的可能性(即使您可以更改标记样式),因此您需要在顶部使用scatterplot 最后一位是图例,您可能希望为第二个 plot 禁用它,以避免重复,但问题是第一个图例中没有标记。 如果这让您感到困扰,您仍然可以手动编辑图例。 总而言之,它可能看起来像这样:

import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Converts a line style to a format acceptable by Seaborn
def get_dash_pattern(style):
    _, dash = mpl.lines._get_dash_pattern(style)
    return dash if dash else (None, None)

data_tmp = pd.DataFrame({
    'x': [0,14,28,42,56, 0,14,28,42,56],
    'y': [0, 0.003, 0.006, 0.008, 0.001, 0*2, 0.003*2, 0.006*2, 0.008*2, 0.001*2],
    'cat': ['A','A','A','A','A','B','B','B','B','B'],
    'color': ['#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0',
              '#247AB2','#247AB2','#247AB2','#247AB2','#247AB2'],
    'point': [14,14,14,14,14,28,28,28,28,28],
    'linestyles':['-','-','-','-','-','--','--','--','--','--']})
# Extract plot features as dicts
feats = (data_tmp[['cat', 'color', 'linestyles', 'point']]
         .set_index('cat').drop_duplicates().to_dict())
palette, dashes, sizes = feats['color'], feats['linestyles'], feats['point']
# Convert line styles to dashes
dashes = {k: get_dash_pattern(v) for k, v in dashes.items()}
# Lines
lines = sns.lineplot(x="x", y="y", hue="cat", style="cat", data=data_tmp,
                     palette=palette, dashes=dashes)
# Points
sns.scatterplot(x="x", y="y", hue="cat", size="cat", data=data_tmp,
                palette=palette, sizes=sizes, legend=False)
# Fix legend
for t, l in zip(lines.legend().get_texts(), lines.legend().get_lines()):
    l.set_marker('o')
    l.set_markersize(sizes.get(l.get_label(), 0) / t.get_fontsize())
plt.show()

Output:

最终情节

这是我在@jdehesa 的帮助下的解决方案

我还将图例放在 plot 之外,并对标签进行了一些抛光

def get_dash_pattern(style):
    _, dash = mpl.lines._get_dash_pattern(style)
    return dash if dash else (None, None)

palette = dict(zip(data_tmp.cat, data_tmp.color))
dashes = dict(zip(data_tmp.cat, data_tmp.linestyles))
dashes = {k: get_dash_pattern(v) for k, v in dashes.items()}

ax = sns.lineplot(x="x", y="y", hue="cat", data=data_tmp, palette=palette, style='cat',  dashes=dashes)
ax = sns.scatterplot(x="point", y="y", hue="cat", data=data_tmp[data_tmp.point == data_tmp.x], palette=palette,
                     legend=False)

ax.set_title('title')
ax.set_ylabel('y label')
ax.set_xlabel('x label')
ax.legend(loc=(1.04, 0))
plt.show()

在此处输入图像描述

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM