
Using seaborn to plot pre-grouped line data

I have data that I have pre-grouped. Specifically, they are precision-recall (PR) curves for 3 different classes, and I want to plot them on the same axes:

    import numpy as np
    data_groups = {
        'ap=0.16: cat_3 (4/19)': {
            'precision': np.array([0.        , 0.        , 0.        , 0.        , 0.2       ,
                                   0.16666667, 0.14285714, 0.25      , 0.22222222, 0.2       ,
                                   0.18181818, 0.16666667, 0.15384615, 0.14285714, 0.13333333,
                                   0.21052632], dtype=np.float64),
            'recall': np.array([0.  , 0.  , 0.  , 0.  , 0.25, 0.25, 0.25, 0.5 , 0.5 , 0.5 , 0.5 ,
                                0.5 , 0.5 , 0.5 , 0.5 , 1.  ], dtype=np.float64),
        },
        'ap=0.20: cat_1 (3/19)': {
            'precision': np.array([0.        , 0.5       , 0.33333333, 0.25      , 0.2       ,
                                   0.16666667, 0.14285714, 0.25      , 0.22222222, 0.2       ,
                                   0.18181818, 0.16666667, 0.15384615, 0.14285714, 0.13333333,
                                   0.15789474], dtype=np.float64),
            'recall': np.array([0.        , 0.33333333, 0.33333333, 0.33333333, 0.33333333,
                                0.33333333, 0.33333333, 0.66666667, 0.66666667, 0.66666667,
                                0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667,
                                1.        ], dtype=np.float64),
        },
        'ap=0.54: cat_2 (8/19)': {
            'precision': np.array([0.        , 0.5       , 0.33333333, 0.5       , 0.6       ,
                                   0.66666667, 0.71428571, 0.75      , 0.66666667, 0.6       ,
                                   0.63636364, 0.58333333, 0.53846154, 0.5       , 0.46666667,
                                   0.42105263], dtype=np.float64),
            'recall': np.array([0.   , 0.125, 0.125, 0.25 , 0.375, 0.5  , 0.625, 0.75 , 0.75 ,
                                0.75 , 0.875, 0.875, 0.875, 0.875, 0.875, 1.   ], dtype=np.float64),
        },
    }

I would like to use seaborn to plot these lines in a single plot, but to do so it seems I need to transform the grouped data into a single long-form pandas table:

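    import pandas as pd
    longform = []
    for key, subdata in data_groups.items():
        # One DataFrame per group, tagged with the group's key
        subdata = pd.DataFrame.from_dict(subdata)
        subdata['label'] = key
        longform.append(subdata)
    # Stack the groups into a single long-form table
    data = pd.concat(longform)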

This effectively duplicates the "label" attribute for every row:

          recall  precision                  label
    0   0.000000   0.000000  ap=0.54: cat_2 (8/19)
    1   0.125000   0.500000  ap=0.54: cat_2 (8/19)
    2   0.125000   0.333333  ap=0.54: cat_2 (8/19)
    ...
    0   0.000000   0.000000  ap=0.20: cat_1 (3/19)
    1   0.333333   0.500000  ap=0.20: cat_1 (3/19)
    2   0.333333   0.333333  ap=0.20: cat_1 (3/19)
    3   0.333333   0.250000  ap=0.20: cat_1 (3/19)
    ...
    0   0.000000   0.000000  ap=0.16: cat_3 (4/19)
    1   0.000000   0.000000  ap=0.16: cat_3 (4/19)
    2   0.000000   0.000000  ap=0.16: cat_3 (4/19)
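The same long-form table can be built a little more compactly. This is a sketch that uses a small made-up stand-in for `data_groups` (swap in the real dict). `DataFrame.assign` broadcasts the scalar key to every row of its group. `ignore_index=True` gives the result a fresh 0..n-1 index, so each group's own index isn't repeated as it is above. The label is still stored once per row, because that repetition is inherent to long-form data.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in with the same shape as data_groups:
# {label: {'precision': array, 'recall': array}}
data_groups = {
    'group_a': {'precision': np.array([0.0, 0.5, 0.4]),
                'recall': np.array([0.0, 0.5, 1.0])},
    'group_b': {'precision': np.array([0.0, 1.0, 0.6, 0.5]),
                'recall': np.array([0.0, 0.25, 0.75, 1.0])},
}

# One DataFrame per group, tagged with its key via assign,
# then stacked into one long-form table with a fresh RangeIndex.
data = pd.concat(
    [pd.DataFrame(sub).assign(label=key) for key, sub in data_groups.items()],
    ignore_index=True,
)
print(data)
```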

At that point I can plot it:

    import seaborn as sns
    sns.lineplot(
        data=data, x='recall', y='precision',
        hue='label', style='label') 

But I was wondering if there is a more efficient way to pass the pre-grouped data to seaborn. I would like to avoid duplicating the "label" attribute, and I imagine seaborn effectively has to invert the pd.concat operation I just performed anyway.

The seaborn page on accepted data structures ( https://seaborn.pydata.org/tutorial/data_structure.html ) only mentions long-form data (which I understand pretty well) and wide-form data (which makes much less sense to me).

This pre-grouped data isn't a wide-form variant, right? I just want to confirm that the extra concat is currently the only way to do this.

You don't have to pass all of the data to seaborn at once. You can plot one line at a time, and all of the lines will still appear on the same plot. Seaborn accepts numpy arrays passed directly as x and y, so you can plot each group separately:

    from matplotlib import pyplot as plt
    import seaborn as sns

    # Each call draws onto the current Axes, so every group ends up on the same plot
    for key, subdata in data_groups.items():
        sns.lineplot(x=subdata['recall'], y=subdata['precision'], label=key)

    plt.show()

Result:

[Figure: the resulting plot; image not preserved]

Of course, you still need to handle extra styling yourself, such as the legend position and the confidence interval. Essentially, though, this plots each group directly without you having to convert it to a DataFrame first.
