
How to plot 20 different scatter subplots together in a FacetGrid - Python (seaborn)

I have seen multiple other threads for seaborn subplots in a facetgrid but none for my particular situation.

I have 20 columns of housing data. I want to plot each variable against the variable 'SalePrice' in a 5-row by 4-column FacetGrid using seaborn.

Here is the list of columns I chose; the resulting dataframe is called 'train_df'.

train_cols_to_keep = ['1stFlrSF', '2ndFlrSF', 'Fireplaces', 'FullBath', 'GarageArea', 'GarageCars', 
'GarageYrBlt', 'GrLivArea', 'HalfBath', 'Id', 'LotArea', 'LotFrontage', 'MasVnrArea', 'OpenPorchSF', 
'OverallQual', 'TotalBsmtSF', 'TotRmsAbvGrd', 'WoodDeckSF', 'YearBuilt', 'YearRemodAdd', 
'SalePrice']

train_data_reduced = train_data[train_cols_to_keep]

train_df = train_data_reduced.fillna(train_data_reduced.mean())

I really have no idea how to do this, and none of the examples I have seen cover plotting each column against one particular column. Thanks. Using the rpy2 library for ggplot would also be acceptable.

I don't think there is a way to get the exact result you want using seaborn. The closest you can get is pairplot(), but you'll get one row with 20 columns of axes:

import seaborn as sns

# df is the dataframe to plot (train_df in the question)
g = sns.pairplot(data=df,
                 x_vars=['1stFlrSF', '2ndFlrSF', 'Fireplaces', 'FullBath', 'GarageArea', 'GarageCars',
                         'GarageYrBlt', 'GrLivArea', 'HalfBath', 'Id', 'LotArea', 'LotFrontage', 'MasVnrArea', 'OpenPorchSF',
                         'OverallQual', 'TotalBsmtSF', 'TotRmsAbvGrd', 'WoodDeckSF', 'YearBuilt', 'YearRemodAdd'],
                 y_vars=['SalePrice'])
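One possible workaround, not part of the original answer, is to reshape the dataframe to long form with melt() and let FacetGrid wrap the facets with col_wrap. The sketch below uses made-up random numbers in place of the real housing data, and the names `long_df`, `variable` and `value` are illustrative choices; it assumes that reshaping the data first is acceptable.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import numpy as np
import pandas as pd
import seaborn as sns

x_cols = ['1stFlrSF', '2ndFlrSF', 'Fireplaces', 'FullBath', 'GarageArea', 'GarageCars',
          'GarageYrBlt', 'GrLivArea', 'HalfBath', 'Id', 'LotArea', 'LotFrontage',
          'MasVnrArea', 'OpenPorchSF', 'OverallQual', 'TotalBsmtSF', 'TotRmsAbvGrd',
          'WoodDeckSF', 'YearBuilt', 'YearRemodAdd']

# Synthetic stand-in for train_df: random values, NOT the real housing data
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.random((30, len(x_cols))), columns=x_cols)
df['SalePrice'] = rng.random(30)

# Long form: one row per (SalePrice, variable name, variable value)
long_df = df.melt(id_vars='SalePrice', var_name='variable', value_name='value')

# One facet per variable, wrapped after 4 columns, so 20 variables give 5 rows.
# sharex=False because every variable has its own scale.
g = sns.FacetGrid(long_df, col='variable', col_wrap=4, sharex=False, sharey=True)
g.map_dataframe(sns.scatterplot, x='value', y='SalePrice')
```

With the real data, `df` would simply be `train_df`. Each facet title shows the variable name, and the shared y axis keeps 'SalePrice' comparable across panels.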

If you use matplotlib directly, you can get the desired result without too much hassle:

import matplotlib.pyplot as plt

# 4 rows x 5 columns of axes sharing the y axis; use (5, 4) for 5 rows by 4 columns
fig, axs = plt.subplots(4, 5, sharey=True)
for ax, col in zip(axs.flat, ['1stFlrSF', '2ndFlrSF', 'Fireplaces', 'FullBath', 'GarageArea', 'GarageCars',
                              'GarageYrBlt', 'GrLivArea', 'HalfBath', 'Id', 'LotArea', 'LotFrontage', 'MasVnrArea', 'OpenPorchSF',
                              'OverallQual', 'TotalBsmtSF', 'TotRmsAbvGrd', 'WoodDeckSF', 'YearBuilt', 'YearRemodAdd']):
    ax.scatter(df[col], df['SalePrice'])
    ax.set_xlabel(col)
    # Axes.is_first_col() is gone from newer matplotlib releases; ask the SubplotSpec instead
    if ax.get_subplotspec().is_first_col():
        ax.set_ylabel('SalePrice')
fig.tight_layout()
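If the number of columns is not an exact multiple of the grid width, the same loop can be wrapped in a small helper that computes the row count and hides the leftover axes. This is only a sketch: `scatter_grid` is a hypothetical name, not a matplotlib or seaborn function, and the demo data is random.

```python
import math
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def scatter_grid(df, x_cols, y_col, ncols=4):
    """Scatter every column in x_cols against y_col on a wrapped grid (hypothetical helper)."""
    nrows = math.ceil(len(x_cols) / ncols)
    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(nrows, ncols, sharey=True, squeeze=False,
                            figsize=(3 * ncols, 2.5 * nrows))
    for ax, col in zip(axs.flat, x_cols):
        ax.scatter(df[col], df[y_col], s=5)
        ax.set_xlabel(col)
    # Label the y axis only on the first column of the grid
    for ax in axs[:, 0]:
        ax.set_ylabel(y_col)
    # Hide any axes left over when len(x_cols) is not a multiple of ncols
    for ax in axs.ravel()[len(x_cols):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig, axs

# Demo with made-up data: 7 predictors on a grid 4 columns wide
rng = np.random.default_rng(1)
demo_cols = ['c%d' % i for i in range(7)]
demo = pd.DataFrame(rng.random((25, 7)), columns=demo_cols)
demo['SalePrice'] = rng.random(25)
fig, axs = scatter_grid(demo, demo_cols, 'SalePrice', ncols=4)
```

For the 20 columns in the question, `scatter_grid(train_df, x_cols, 'SalePrice', ncols=4)` would give the 5-row by 4-column layout, with `x_cols` being the list of column names without 'SalePrice'.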

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this post, please indicate the site URL or the original address.
