
Is it possible to adjust the size of squares (cells) in a Seaborn heatmap?


Suppose I have a heatmap plot like this:

[Image: the heatmap]

Created from this data:

import numpy as np
import pandas as pd


arr = np.array([[ 2,  2,  2,  8,  7,  7,  6,  5,  2,  7,  7,  8,  7,  5,  6,  6,  6],
       [ 8,  7,  5,  4,  4,  3,  9,  6,  7,  4,  3,  2,  8,  9,  3,  3,  3],
       [ 1,  3,  2,  2,  2,  3,  5,  3,  3,  2,  3,  3,  4,  1, 10, 10, 10],
       [ 3,  2,  4,  1,  1,  1,  2,  2,  1,  1,  1,  1,  2,  1,  9,  9,  9],
       [ 7,  6,  7,  6,  6,  6,  2,  2,  5,  6,  5,  4,  7,  9,  9,  9,  9],
       [ 6,  7,  8,  4,  3,  4,  4,  8,  7,  3,  4,  5,  6,  3,  4,  4,  4],
       [ 3,  1,  1,  9,  9,  9,  3,  1,  8,  9,  9,  9,  1,  6,  1,  1,  1],
       [ 3,  3,  3,  5,  5,  5,  5,  1,  2,  5,  6,  5, 10,  8,  8,  8,  8],
       [ 1,  1,  1,  2,  3,  2,  7,  3,  1,  3,  2,  2, 10,  8,  7,  7,  7],
       [ 5,  5,  2,  2,  2,  1,  1,  3,  3,  2,  1,  1,  5,  2,  7,  7,  7],
       [ 7,  9, 10,  3,  4,  4,  8,  9,  9,  3,  4,  6,  2,  3,  2,  2,  2],
       [ 5,  6,  7,  3,  3,  3,  3,  1,  4,  4,  3,  4,  9, 10,  2,  2,  2],
       [ 4,  4,  3,  4,  4,  4,  3,  4,  3,  4,  4,  3,  2,  7, 10, 10, 10],
       [ 2,  1,  1,  8,  8,  8,  1,  4,  2,  8,  8,  8,  4,  1,  5,  5,  5],
       [ 9,  9,  8,  8,  8,  8,  5,  6,  8,  8,  8,  5,  1,  5,  2,  2,  2],
       [ 5,  5,  5,  5,  5,  5,  4,  2,  1,  5,  5,  4,  6,  5,  5,  5,  5],
       [ 8,  8,  9, 10, 10, 10,  6,  7,  6, 10, 10, 10,  3,  7,  4,  4,  4],
       [ 9,  8, 10,  5,  7,  7, 10, 10,  9,  6,  5,  6,  5,  6,  3,  3,  3],
       [10,  9,  9,  7,  6,  5, 10, 10,  9,  8,  7,  8,  3, 10,  8,  8,  8],
       [10, 10,  8, 10, 10, 10,  2,  5, 10, 10, 10,  9,  7,  9,  3,  3,  3],
       [ 4,  4,  5,  3,  2,  2,  9,  8,  4,  2,  2,  3,  4,  4,  5,  5,  5],
       [ 4,  4,  4,  7,  5,  6,  4,  4,  4,  5,  6,  7, 10,  2,  8,  8,  8],
       [ 7,  8,  6,  6,  8,  8,  7,  9,  8,  7,  8,  7,  9,  8,  6,  6,  6],
       [ 8,  7,  7,  7,  7,  7,  8,  9,  5,  7,  7,  7,  5,  7,  1,  1,  1],
       [ 1,  2,  3,  1,  1,  1,  9,  7,  7,  1,  1,  1,  9,  3,  4,  4,  4],
       [ 2,  5,  6,  1,  1,  2,  7,  5,  6,  1,  2,  2,  8,  4,  1,  1,  1],
       [10, 10,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10,  3, 10,  7,  7,  7],
       [ 6,  3,  4,  9,  9,  9,  8,  7,  5,  9,  9, 10,  1,  2, 10, 10, 10],
       [ 9, 10, 10,  9,  9,  9,  1,  8, 10,  9,  9,  9,  8,  4,  9,  9,   9]])

columns = ["feature1", "feature2", "feature3", "feature4", "feature5", "feature6", "feature7", "feature8", "feature9", "feature10", "feature11", "feature12", "feature13", "feature14", "feature15", "feature16", "feature17"]

indexes = ['AAPL', 'AMGN', 'AXP', 'BA', 'CAT', 'CRM', 'CSCO', 'CVX', 'DIS', 'GS',
       'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'JPM', 'KO', 'MCD', 'MMM', 'MRK',
       'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'V', 'VZ', 'WBA', 'WMT']

df = pd.DataFrame(arr, columns=columns, index=indexes)

And this code:

import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10,10), dpi=600)
a = sns.heatmap(df, annot=True, cmap="RdBu_r", square=True, ax=ax)
plt.show()

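I want to adjust the size of each cell according to its value. In other words, a square cell with a value of 1 should be smaller than cells with higher values!
Example: [Image: a heatmap whose square cells vary in size]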

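Note that this example is not strictly tied to the values of the heatmap above; I only included it to illustrate what I mean by sizing each square cell according to its value.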

This is something you could accomplish with scatterplot or relplot:

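import seaborn as sns
import matplotlib.pyplot as plt

# Example dataset shipped with seaborn (fetched and cached on first use)
flights = sns.load_dataset("flights")
# Square markers ("s") whose size and color both encode the passenger count
g = sns.relplot(
    data=flights,
    x="year", y="month", size="passengers", hue="passengers",
    marker="s", sizes=(40, 400), palette="blend:b,r",
)
plt.show()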

[Image: relplot of the flights dataset with square markers sized and colored by passengers]
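In the snippet above, sizes=(40, 400) sets the smallest and largest marker size. The sketch below mimics, with plain NumPy, how a numeric size column is spread across that range, assuming seaborn's default linear normalization between the data minimum and maximum. marker_areas is a hypothetical helper, not a seaborn function. Because matplotlib's scatter treats these numbers as marker area in points squared, the side of a square marker grows only with the square root of the mapped value.

```python
import numpy as np


def marker_areas(values, sizes=(40, 400)):
    """Hypothetical helper approximating a linear (min, max) size mapping.

    Assumption: values are normalized linearly between the data minimum and
    maximum, and the result is interpolated into `sizes`. The output is read
    by matplotlib's scatter as marker area in points squared.
    """
    values = np.asarray(values, dtype=float)
    lo, hi = sizes
    normed = (values - values.min()) / (values.max() - values.min())
    return lo + normed * (hi - lo)


areas = marker_areas([1, 5.5, 10])
# A square marker's side length grows like the square root of its area
side_in_points = np.sqrt(areas)
```

If you want the side length (rather than the area) to be linear in the value, you would have to pass squared values to the size mapping yourself.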

(This answer builds on @mwaskom's excellent solution, adapted to the given dataframe.)

For most seaborn functions, it helps to have the dataframe in "long form".
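To make "long form" concrete, here is a minimal sketch on a 2x2 slice of the question's data (the first two tickers and features). Each row of the long table is one observation: a ticker, a feature name, and the value in that cell. The variable names are illustrative.

```python
import pandas as pd

# Wide form: one row per ticker, one column per feature (as sns.heatmap expects)
wide = pd.DataFrame(
    {"feature1": [2, 8], "feature2": [2, 7]},
    index=pd.Index(["AAPL", "AMGN"], name="Ticker"),
)

# Long form: one row per (Ticker, Feature) pair, with the cell value in "Value".
# melt stacks the feature columns one after another.
long_df = wide.reset_index().melt(id_vars="Ticker", var_name="Feature", value_name="Value")
print(long_df)
```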

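Here is an example of how to convert the dataframe to long form, which is the format expected by e.g. sns.relplot and sns.scatterplot. It may be easier to start from the original dataframe that was used to create the pivot table.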
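If a long-form table like the hypothetical long_df below is what the wide table was originally pivoted from, it can be passed straight to sns.relplot. The sketch shows that pivot and melt are inverses here, so either shape can be recovered from the other.

```python
import pandas as pd

# Hypothetical original long-form data (first two tickers and features)
long_df = pd.DataFrame({
    "Ticker": ["AAPL", "AMGN", "AAPL", "AMGN"],
    "Feature": ["feature1", "feature1", "feature2", "feature2"],
    "Value": [2, 8, 2, 7],
})

# long -> wide: the matrix shape used by sns.heatmap
wide = long_df.pivot(index="Ticker", columns="Feature", values="Value")

# wide -> long: the shape used by sns.relplot / sns.scatterplot
back = wide.reset_index().melt(id_vars="Ticker", var_name="Feature", value_name="Value")
```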

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

arr = np.array([[2, 2, 2, 8, 7, 7, 6, 5, 2, 7, 7, 8, 7, 5, 6, 6, 6], [8, 7, 5, 4, 4, 3, 9, 6, 7, 4, 3, 2, 8, 9, 3, 3, 3], [1, 3, 2, 2, 2, 3, 5, 3, 3, 2, 3, 3, 4, 1, 10, 10, 10], [3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 9, 9, 9], [7, 6, 7, 6, 6, 6, 2, 2, 5, 6, 5, 4, 7, 9, 9, 9, 9], [6, 7, 8, 4, 3, 4, 4, 8, 7, 3, 4, 5, 6, 3, 4, 4, 4], [3, 1, 1, 9, 9, 9, 3, 1, 8, 9, 9, 9, 1, 6, 1, 1, 1], [3, 3, 3, 5, 5, 5, 5, 1, 2, 5, 6, 5, 10, 8, 8, 8, 8], [1, 1, 1, 2, 3, 2, 7, 3, 1, 3, 2, 2, 10, 8, 7, 7, 7], [5, 5, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 5, 2, 7, 7, 7], [7, 9, 10, 3, 4, 4, 8, 9, 9, 3, 4, 6, 2, 3, 2, 2, 2], [5, 6, 7, 3, 3, 3, 3, 1, 4, 4, 3, 4, 9, 10, 2, 2, 2], [4, 4, 3, 4, 4, 4, 3, 4, 3, 4, 4, 3, 2, 7, 10, 10, 10], [2, 1, 1, 8, 8, 8, 1, 4, 2, 8, 8, 8, 4, 1, 5, 5, 5], [9, 9, 8, 8, 8, 8, 5, 6, 8, 8, 8, 5, 1, 5, 2, 2, 2], [5, 5, 5, 5, 5, 5, 4, 2, 1, 5, 5, 4, 6, 5, 5, 5, 5], [8, 8, 9, 10, 10, 10, 6, 7, 6, 10, 10, 10, 3, 7, 4, 4, 4], [9, 8, 10, 5, 7, 7, 10, 10, 9, 6, 5, 6, 5, 6, 3, 3, 3], [10, 9, 9, 7, 6, 5, 10, 10, 9, 8, 7, 8, 3, 10, 8, 8, 8], [10, 10, 8, 10, 10, 10, 2, 5, 10, 10, 10, 9, 7, 9, 3, 3, 3], [4, 4, 5, 3, 2, 2, 9, 8, 4, 2, 2, 3, 4, 4, 5, 5, 5], [4, 4, 4, 7, 5, 6, 4, 4, 4, 5, 6, 7, 10, 2, 8, 8, 8], [7, 8, 6, 6, 8, 8, 7, 9, 8, 7, 8, 7, 9, 8, 6, 6, 6], [8, 7, 7, 7, 7, 7, 8, 9, 5, 7, 7, 7, 5, 7, 1, 1, 1], [1, 2, 3, 1, 1, 1, 9, 7, 7, 1, 1, 1, 9, 3, 4, 4, 4], [2, 5, 6, 1, 1, 2, 7, 5, 6, 1, 2, 2, 8, 4, 1, 1, 1], [10, 10, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 3, 10, 7, 7, 7], [6, 3, 4, 9, 9, 9, 8, 7, 5, 9, 9, 10, 1, 2, 10, 10, 10], [9, 10, 10, 9, 9, 9, 1, 8, 10, 9, 9, 9, 8, 4, 9, 9, 9]])
columns = [f"feature{i}" for i in range(1, 18)]
indexes = ['AAPL', 'AMGN', 'AXP', 'BA', 'CAT', 'CRM', 'CSCO', 'CVX', 'DIS', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'JPM', 'KO', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'V', 'VZ', 'WBA', 'WMT']
df = pd.DataFrame(arr, columns=columns, index=indexes)
df.index.name = 'Ticker'

# Convert the wide table to long form: one row per (Ticker, Feature, Value)
df_long = df.reset_index().melt(id_vars='Ticker', var_name='Feature', value_name='Value')
sns.set_style('darkgrid')
# Square markers whose size and color both encode the value
g = sns.relplot(data=df_long, x="Feature", y="Ticker", size="Value", hue="Value",
                marker="s", sizes=(20, 200), palette="blend:limegreen,orange", height=8, aspect=1.1)
g.ax.tick_params(axis='x', labelrotation=45)
g.ax.set_facecolor('aliceblue')
g.ax.grid(color='red', lw=1)

# Leave room for the ticker labels and the rotated feature labels
g.fig.subplots_adjust(left=0.1, bottom=0.15)
plt.show()

[Image: sns.relplot using the long-form dataframe]
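One limitation of the scatter approach is that marker sizes are measured in points, so they do not rescale with the figure and a square is not guaranteed to fill its grid cell. A different technique, not taken from the answers above, is to draw each cell as a matplotlib Rectangle patch in data coordinates, centered on the cell, with a side length proportional to the value. sized_square_heatmap is a hypothetical helper; Rectangle, Normalize, plt.get_cmap and colorbar are standard matplotlib. It assumes all values are positive; use the square root of the value instead if you want the area, rather than the side, to be proportional.

```python
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle


def sized_square_heatmap(df, ax, cmap="RdBu_r", max_side=0.9):
    """Hypothetical helper: draw one square per cell of `df` on `ax`.

    The largest value gets a side of `max_side` (in cell units); other sides
    scale linearly with the value. Assumes all values are positive.
    """
    values = df.to_numpy(dtype=float)
    n_rows, n_cols = values.shape
    norm = Normalize(vmin=values.min(), vmax=values.max())
    colormap = plt.get_cmap(cmap)
    patches = []
    for i in range(n_rows):          # row i is drawn at y = i
        for j in range(n_cols):      # column j is drawn at x = j
            side = max_side * values[i, j] / values.max()
            # Rectangle is anchored at its (min x, min y) corner, so shift by
            # half a side to center the square on (j, i)
            square = Rectangle((j - side / 2, i - side / 2), side, side,
                               facecolor=colormap(norm(values[i, j])))
            ax.add_patch(square)
            patches.append(square)
    # Put the first row at the top, like sns.heatmap does
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(n_rows - 0.5, -0.5)
    ax.set_aspect("equal")
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(df.columns, rotation=90)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(df.index)
    # Colorbar for the value -> color mapping
    ax.figure.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=colormap), ax=ax)
    return patches


# Usage on a tiny 2x2 frame; pass the question's `df` to plot the real data
small = pd.DataFrame([[1, 10], [5, 2]], index=["AAPL", "AMGN"],
                     columns=["feature1", "feature2"])
fig, ax = plt.subplots(figsize=(4, 4))
patches = sized_square_heatmap(small, ax)
plt.show()
```

Because the squares live in data coordinates, they keep their relative size to the grid when the figure is resized, which the point-based scatter markers do not.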


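Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.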

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM