I have data that I have pre-grouped. Specifically, they are precision-recall (PR) curves for 3 different classes, and I want to plot them on the same axes:
import numpy as np
data_groups = {
    'ap=0.16: cat_3 (4/19)': {
        'precision': np.array([0., 0., 0., 0., 0.2,
                               0.16666667, 0.14285714, 0.25, 0.22222222, 0.2,
                               0.18181818, 0.16666667, 0.15384615, 0.14285714, 0.13333333,
                               0.21052632], dtype=np.float64),
        'recall': np.array([0., 0., 0., 0., 0.25, 0.25, 0.25, 0.5, 0.5, 0.5, 0.5,
                            0.5, 0.5, 0.5, 0.5, 1.], dtype=np.float64),
    },
    'ap=0.20: cat_1 (3/19)': {
        'precision': np.array([0., 0.5, 0.33333333, 0.25, 0.2,
                               0.16666667, 0.14285714, 0.25, 0.22222222, 0.2,
                               0.18181818, 0.16666667, 0.15384615, 0.14285714, 0.13333333,
                               0.15789474], dtype=np.float64),
        'recall': np.array([0., 0.33333333, 0.33333333, 0.33333333, 0.33333333,
                            0.33333333, 0.33333333, 0.66666667, 0.66666667, 0.66666667,
                            0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667,
                            1.], dtype=np.float64),
    },
    'ap=0.54: cat_2 (8/19)': {
        'precision': np.array([0., 0.5, 0.33333333, 0.5, 0.6,
                               0.66666667, 0.71428571, 0.75, 0.66666667, 0.6,
                               0.63636364, 0.58333333, 0.53846154, 0.5, 0.46666667,
                               0.42105263], dtype=np.float64),
        'recall': np.array([0., 0.125, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.75,
                            0.75, 0.875, 0.875, 0.875, 0.875, 0.875, 1.], dtype=np.float64),
    },
}
I would like to use seaborn to plot these multiple lines in a single plot, but to do so I seem to need to transform this grouped data into a single long-form pandas table.
import pandas as pd
longform = []
for key, subdata in data_groups.items():
    subdata = pd.DataFrame.from_dict(subdata)
    subdata['label'] = key
    longform.append(subdata)
data = pd.concat(longform)
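If the concern is only the length of that loop, pandas can do the same conversion in one call: when `pd.concat` receives a dict of DataFrames, it uses the dict keys as the outer level of the resulting index, and `reset_index` then turns that level into an ordinary column. A minimal sketch; the helper name `to_long_form` and the index level names `label`/`point` are hypothetical choices, not part of either library:

```python
import pandas as pd


def to_long_form(data_groups):
    """Hypothetical helper: flatten {label: {"precision": ..., "recall": ...}}
    into a single long-form DataFrame with a "label" column."""
    # One DataFrame per group, keyed by the group's label.
    frames = {key: pd.DataFrame(subdata) for key, subdata in data_groups.items()}
    # Dict keys become the outer index level ("label"); the inner level
    # ("point") is each row's position within its own group.
    long_df = pd.concat(frames, names=["label", "point"])
    # Move both index levels into regular columns.
    return long_df.reset_index()
```

With the data above, `data = to_long_form(data_groups)` replaces the loop. The label strings are still repeated once per row, since that repetition is what long-form data is.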
This effectively duplicates the "label" value for every row:
recall precision label
0 0.000000 0.000000 ap=0.54: cat_2 (8/19)
1 0.125000 0.500000 ap=0.54: cat_2 (8/19)
2 0.125000 0.333333 ap=0.54: cat_2 (8/19)
...
0 0.000000 0.000000 ap=0.20: cat_1 (3/19)
1 0.333333 0.500000 ap=0.20: cat_1 (3/19)
2 0.333333 0.333333 ap=0.20: cat_1 (3/19)
3 0.333333 0.250000 ap=0.20: cat_1 (3/19)
...
0 0.000000 0.000000 ap=0.16: cat_3 (4/19)
1 0.000000 0.000000 ap=0.16: cat_3 (4/19)
2 0.000000 0.000000 ap=0.16: cat_3 (4/19)
At that point I can plot it:
import seaborn as sns
sns.lineplot(
    data=data, x='recall', y='precision',
    hue='label', style='label')
But I was wondering whether there is a more efficient way to pass the pre-grouped data to seaborn. I would like to avoid duplicating the "label" attribute, and I imagine seaborn must effectively be inverting the pd.concat operation I just performed.
The seaborn documentation on accepted data structures ( https://seaborn.pydata.org/tutorial/data_structure.html ) only mentions long-form data (which I understand pretty well) and wide-form data (which makes much less sense to me).
This pre-grouped data isn't a wide-form variant, right? I just want to verify that performing the extra concat is currently the only way to do this.
You don't have to pass all the data to seaborn at once. You can plot one line at a time, and the lines will all appear on the same axes. Seaborn accepts plain NumPy arrays passed directly as x and y, so you can plot each group separately and it still works:
from matplotlib import pyplot as plt
import seaborn as sns
for key, subdata in data_groups.items():
    sns.lineplot(x=subdata['recall'], y=subdata['precision'], label=key)
plt.show()
Of course, you still need to handle extra styling such as legend position and the confidence interval, but essentially this plots each group directly, without converting it to a DataFrame.