簡體   English   中英

如何使用 plot 在具有特定 colors 和每條線的線型的特定點上使用 seaborn 繪制線圖?

[英]How to plot a lineplot with dots on specific points with specific colors and linetypes for each line using seaborn?

我有以下 dataframe

import pandas as pd



 data_tmp = pd.DataFrame({'x': [0,14,28,42,56, 0,14,28,42,56],
                         'y': [0, 0.003, 0.006, 0.008, 0.001, 0*2, 0.003*2, 0.006*2, 0.008*2, 0.001*2],
                         'cat': ['A','A','A','A','A','B','B','B','B','B'],
                         'color': ['#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0','#247AB2','#247AB2','#247AB2','#247AB2','#247AB2'],
                         'point': [14,14,14,14,14,28,28,28,28,28],
                         'linestyles':['-','-','-','-','-','--','--','--','--','--']})

我想為每只cat制作一個具有不同colorlinestyles樣式的線圖。 但我想為每cat提供特定的colordataframe linestyles定義。 最后,我想用相同的顏色標記每條線上的point s。

我只試過:

sns.lineplot(x="x", y="y", hue="cat", data=data_tmp)
sns.scatterplot(x="point",y="y",hue="cat", data=data_tmp[data_tmp.point==data_tmp.x])
plt.show()

有任何想法嗎?

也許你想直接使用 matplotlib ,比如

import pandas as pd
import matplotlib.pyplot as plt


df = pd.DataFrame({'x': [0,14,28,42,56, 0,14,28,42,56],
                   'y': [0, 0.003, 0.006, 0.008, 0.001, 0*2, 0.003*2, 0.006*2, 0.008*2, 0.001*2],
                   'cat': ['A','A','A','A','A','B','B','B','B','B'],})


d = {"A" : {"color": '#B5D8F0', "markersize":  5, "linestyle": "-"},
     "B" : {"color": '#247AB2', "markersize": 10, "linestyle": "--"}}

for n, grp in df.groupby("cat"):
    plt.plot(grp.x, grp.y, marker="o", label=n, **d[n])

plt.legend()
plt.show()

在此處輸入圖像描述

這就是我能做到的。 您需要使用cat列來控制不同的 plot 參數(顏色、樣式、標記大小),然后創建映射對象(此處為 dicts)來告訴每個類別使用哪個參數值。 顏色很簡單。 線條樣式更難,因為 Seaborn 僅提供dashes作為可配置參數,需要在(segment, gap)的高級 Matplotlib 格式中給出。 function matplotlib.lines._get_dash_pattern將字符串值(例如-- )轉換為這種格式,盡管返回值需要小心處理。 對於標記大小,不幸的lineplot不提供隨類別更改標記大小的可能性(即使您可以更改標記樣式),因此您需要在頂部使用scatterplot 最后一位是圖例,您可能希望為第二個 plot 禁用它,以避免重復,但問題是第一個圖例中沒有標記。 如果這讓您感到困擾,您仍然可以手動編輯圖例。 總而言之,它可能看起來像這樣:

import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Converts a line style to a format acceptable by Seaborn
def get_dash_pattern(style):
    _, dash = mpl.lines._get_dash_pattern(style)
    return dash if dash else (None, None)

data_tmp = pd.DataFrame({
    'x': [0,14,28,42,56, 0,14,28,42,56],
    'y': [0, 0.003, 0.006, 0.008, 0.001, 0*2, 0.003*2, 0.006*2, 0.008*2, 0.001*2],
    'cat': ['A','A','A','A','A','B','B','B','B','B'],
    'color': ['#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0','#B5D8F0',
              '#247AB2','#247AB2','#247AB2','#247AB2','#247AB2'],
    'point': [14,14,14,14,14,28,28,28,28,28],
    'linestyles':['-','-','-','-','-','--','--','--','--','--']})
# Extract plot features as dicts
feats = (data_tmp[['cat', 'color', 'linestyles', 'point']]
         .set_index('cat').drop_duplicates().to_dict())
palette, dashes, sizes = feats['color'], feats['linestyles'], feats['point']
# Convert line styles to dashes
dashes = {k: get_dash_pattern(v) for k, v in dashes.items()}
# Lines
lines = sns.lineplot(x="x", y="y", hue="cat", style="cat", data=data_tmp,
                     palette=palette, dashes=dashes)
# Points
sns.scatterplot(x="x", y="y", hue="cat", size="cat", data=data_tmp,
                palette=palette, sizes=sizes, legend=False)
# Fix legend
for t, l in zip(lines.legend().get_texts(), lines.legend().get_lines()):
    l.set_marker('o')
    l.set_markersize(sizes.get(l.get_label(), 0) / t.get_fontsize())
plt.show()

Output:

最終情節

這是我在@jdehesa 的幫助下的解決方案

我還將圖例放在 plot 之外,並對標簽進行了一些拋光

def get_dash_pattern(style):
    _, dash = mpl.lines._get_dash_pattern(style)
    return dash if dash else (None, None)

palette = dict(zip(data_tmp.cat, data_tmp.color))
dashes = dict(zip(data_tmp.cat, data_tmp.linestyles))
dashes = {k: get_dash_pattern(v) for k, v in dashes.items()}

ax = sns.lineplot(x="x", y="y", hue="cat", data=data_tmp, palette=palette, style='cat',  dashes=dashes)
ax = sns.scatterplot(x="point", y="y", hue="cat", data=data_tmp[data_tmp.point == data_tmp.x], palette=palette,
                     legend=False)

ax.set_title('title')
ax.set_ylabel('y label')
ax.set_xlabel('x label')
ax.legend(loc=(1.04, 0))
plt.show()

在此處輸入圖像描述

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM