seaborn two corner pairplot

I'd like to create a single pairplot by combining two corner pairplots. Using

import seaborn as sns; sns.set(style="ticks", color_codes=True)
iris = sns.load_dataset("iris")
g = sns.pairplot(iris, hue="species", corner=True)

I obtain a lower triangle of the grid. What I'd like to do is to put another pairplot on the upper (off-diagonal) part of the grid using a different value for hue.

import seaborn as sns; sns.set(style="ticks", color_codes=True)
iris = sns.load_dataset("iris")
iris['species'] = iris['species'].map({'setosa': 0, 
                                   'versicolor': 1, 
                                   'virginica': 2})

sns.pairplot(iris, hue="species", corner=True)
sns.pairplot(iris, hue="petal_length", corner=True)

Is there a way to plot on the upper triangle, or to join two different pairplots?

Thanks in advance

There's no way to make pairplot draw only the upper triangle. What you could do, however, is make two plots, at least one of them with corner=False, and then add the lower-triangle and diagonal axes from the corner plot to the full plot. This only makes sense if the pairplot parameters for both plots are identical; otherwise (as in your example) the axis labels and the legend will be valid for one triangle only. You could manually add a second legend and right and top axes to the upper-triangle subplots, but in that case it will probably be easier to roll your own from the very beginning.

Example (lower triangle and diagonal from the odd-indexed rows of iris, upper triangle from the even-indexed rows):

import matplotlib.pyplot as plt
import seaborn as sns; sns.set(style="ticks", color_codes=True)
iris = sns.load_dataset("iris")

pg1 = sns.pairplot(iris[1::2], hue="species", corner=True)
pg2 = sns.pairplot(iris[::2], hue="species", corner=False, diag_kind=None)

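# remove lower triangle and diagonal from figure 2
# (the list holds 1-based subplot numbers of the 4x4 grid, counted row by row)
for ax in pg2.fig.get_axes():
    # get_subplotspec().num1 is the 0-based subplot number; newer Matplotlib
    # releases no longer provide Axes.get_geometry()
    if ax.get_subplotspec().num1 + 1 in [1,5,6,9,10,11,13,14,15,16]:
        ax.remove()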

# add all axes from figure 1 (lower triangle and diagonal) to figure 2
for ax in pg1.fig.get_axes():
    ax.figure = pg2.fig # in the next step we can only add axes from the same figure
    pg2.fig.add_axes(ax)

# close figure 1 which is not needed anymore    
plt.close(pg1.fig)


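To save some effort, the subplot indexes of the lower triangle and diagonal (the list hard-coded above) can be generated for any k x k grid:

lower_triangle_list = []

k = 4  # number of variables per side of the grid (4 for iris)

# rows are numbered 1..k; row i keeps its first i columns, up to the diagonal
for i in range(1, k+1):
    for c in range(i):
        lower_triangle_list.append((i-1)*k+c+1)

print(lower_triangle_list)  # [1, 5, 6, 9, 10, 11, 13, 14, 15, 16]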
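If the grid exposes its axes as a 2D array (PairGrid does, through its axes attribute), the subplot numbers can be skipped altogether. The following is a rough sketch using NumPy's tril_indices_from. The helper name remove_lower_triangle is invented here, and the function is only meant for a full (non-corner) square grid.

```python
import matplotlib.pyplot as plt
import numpy as np


def remove_lower_triangle(axes: np.ndarray, include_diagonal: bool = True) -> int:
    """Remove the lower-triangle axes of a square 2D axes array from their
    figure and return how many were removed (hypothetical helper)."""
    # k=0 selects the diagonal and everything below it, k=-1 only what is below.
    rows, cols = np.tril_indices_from(axes, k=0 if include_diagonal else -1)
    removed = 0
    for ax in axes[rows, cols]:
        if ax is not None:  # corner grids store None for axes that do not exist
            ax.remove()
            removed += 1
    return removed
```

For the example above this would be called as remove_lower_triangle(pg2.axes) in place of the loop over hard-coded numbers.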
