[英]Is it possible to adjust the size of squares (cells) in Seaborn heatmap?
假設我有一個像這樣的熱圖 plot:
使用這些數據:
import numpy as np
import pandas as pd
arr = np.array([[ 2, 2, 2, 8, 7, 7, 6, 5, 2, 7, 7, 8, 7, 5, 6, 6, 6],
[ 8, 7, 5, 4, 4, 3, 9, 6, 7, 4, 3, 2, 8, 9, 3, 3, 3],
[ 1, 3, 2, 2, 2, 3, 5, 3, 3, 2, 3, 3, 4, 1, 10, 10, 10],
[ 3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 9, 9, 9],
[ 7, 6, 7, 6, 6, 6, 2, 2, 5, 6, 5, 4, 7, 9, 9, 9, 9],
[ 6, 7, 8, 4, 3, 4, 4, 8, 7, 3, 4, 5, 6, 3, 4, 4, 4],
[ 3, 1, 1, 9, 9, 9, 3, 1, 8, 9, 9, 9, 1, 6, 1, 1, 1],
[ 3, 3, 3, 5, 5, 5, 5, 1, 2, 5, 6, 5, 10, 8, 8, 8, 8],
[ 1, 1, 1, 2, 3, 2, 7, 3, 1, 3, 2, 2, 10, 8, 7, 7, 7],
[ 5, 5, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 5, 2, 7, 7, 7],
[ 7, 9, 10, 3, 4, 4, 8, 9, 9, 3, 4, 6, 2, 3, 2, 2, 2],
[ 5, 6, 7, 3, 3, 3, 3, 1, 4, 4, 3, 4, 9, 10, 2, 2, 2],
[ 4, 4, 3, 4, 4, 4, 3, 4, 3, 4, 4, 3, 2, 7, 10, 10, 10],
[ 2, 1, 1, 8, 8, 8, 1, 4, 2, 8, 8, 8, 4, 1, 5, 5, 5],
[ 9, 9, 8, 8, 8, 8, 5, 6, 8, 8, 8, 5, 1, 5, 2, 2, 2],
[ 5, 5, 5, 5, 5, 5, 4, 2, 1, 5, 5, 4, 6, 5, 5, 5, 5],
[ 8, 8, 9, 10, 10, 10, 6, 7, 6, 10, 10, 10, 3, 7, 4, 4, 4],
[ 9, 8, 10, 5, 7, 7, 10, 10, 9, 6, 5, 6, 5, 6, 3, 3, 3],
[10, 9, 9, 7, 6, 5, 10, 10, 9, 8, 7, 8, 3, 10, 8, 8, 8],
[10, 10, 8, 10, 10, 10, 2, 5, 10, 10, 10, 9, 7, 9, 3, 3, 3],
[ 4, 4, 5, 3, 2, 2, 9, 8, 4, 2, 2, 3, 4, 4, 5, 5, 5],
[ 4, 4, 4, 7, 5, 6, 4, 4, 4, 5, 6, 7, 10, 2, 8, 8, 8],
[ 7, 8, 6, 6, 8, 8, 7, 9, 8, 7, 8, 7, 9, 8, 6, 6, 6],
[ 8, 7, 7, 7, 7, 7, 8, 9, 5, 7, 7, 7, 5, 7, 1, 1, 1],
[ 1, 2, 3, 1, 1, 1, 9, 7, 7, 1, 1, 1, 9, 3, 4, 4, 4],
[ 2, 5, 6, 1, 1, 2, 7, 5, 6, 1, 2, 2, 8, 4, 1, 1, 1],
[10, 10, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 3, 10, 7, 7, 7],
[ 6, 3, 4, 9, 9, 9, 8, 7, 5, 9, 9, 10, 1, 2, 10, 10, 10],
[ 9, 10, 10, 9, 9, 9, 1, 8, 10, 9, 9, 9, 8, 4, 9, 9, 9]])
columns = ["feature1", "feature2", "feature3", "feature4", "feature5", "feature6", "feature7", "feature8", "feature9", "feature10", "feature11", "feature12", "feature13", "feature14", "feature15", "feature16", "feature17"]
indexes = ['AAPL', 'AMGN', 'AXP', 'BA', 'CAT', 'CRM', 'CSCO', 'CVX', 'DIS', 'GS',
'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'JPM', 'KO', 'MCD', 'MMM', 'MRK',
'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'V', 'VZ', 'WBA', 'WMT']
df = pd.DataFrame(arr, columns=columns, index=indexes)
使用此代碼:
import seaborn as sns
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10,10), dpi=600)
a = sns.heatmap(df, annot=True, cmap="RdBu_r", square=True, ax=ax)
plt.show()
我想根據它的值調整每個單元格的大小,我的意思是,值為 1 的方形單元格應該小於具有更高值的單元格!
例子:
請注意,此示例與之前的熱圖 plot 的值沒有嚴格相關。 我只是提供了一個示例來說明我的意思,即根據其值調整每個方形單元格的大小。
(這篇文章詳細闡述了@mwaskom 的優秀解決方案,適用於給定的 dataframe。)
對於大多數 seaborn 功能,將 dataframe 設置為“長格式”會有所幫助。
這是一個示例,說明如何將 dataframe 轉換為長格式以獲取例如sns.relplot
或sns.scatterplot
使用的格式。 可能,從用於創建 pivot 表的原始 dataframe 開始會更容易。
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
arr = np.array([[2, 2, 2, 8, 7, 7, 6, 5, 2, 7, 7, 8, 7, 5, 6, 6, 6], [8, 7, 5, 4, 4, 3, 9, 6, 7, 4, 3, 2, 8, 9, 3, 3, 3], [1, 3, 2, 2, 2, 3, 5, 3, 3, 2, 3, 3, 4, 1, 10, 10, 10], [3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 9, 9, 9], [7, 6, 7, 6, 6, 6, 2, 2, 5, 6, 5, 4, 7, 9, 9, 9, 9], [6, 7, 8, 4, 3, 4, 4, 8, 7, 3, 4, 5, 6, 3, 4, 4, 4], [3, 1, 1, 9, 9, 9, 3, 1, 8, 9, 9, 9, 1, 6, 1, 1, 1], [3, 3, 3, 5, 5, 5, 5, 1, 2, 5, 6, 5, 10, 8, 8, 8, 8], [1, 1, 1, 2, 3, 2, 7, 3, 1, 3, 2, 2, 10, 8, 7, 7, 7], [5, 5, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 5, 2, 7, 7, 7], [7, 9, 10, 3, 4, 4, 8, 9, 9, 3, 4, 6, 2, 3, 2, 2, 2], [5, 6, 7, 3, 3, 3, 3, 1, 4, 4, 3, 4, 9, 10, 2, 2, 2], [4, 4, 3, 4, 4, 4, 3, 4, 3, 4, 4, 3, 2, 7, 10, 10, 10], [2, 1, 1, 8, 8, 8, 1, 4, 2, 8, 8, 8, 4, 1, 5, 5, 5], [9, 9, 8, 8, 8, 8, 5, 6, 8, 8, 8, 5, 1, 5, 2, 2, 2], [5, 5, 5, 5, 5, 5, 4, 2, 1, 5, 5, 4, 6, 5, 5, 5, 5], [8, 8, 9, 10, 10, 10, 6, 7, 6, 10, 10, 10, 3, 7, 4, 4, 4], [9, 8, 10, 5, 7, 7, 10, 10, 9, 6, 5, 6, 5, 6, 3, 3, 3], [10, 9, 9, 7, 6, 5, 10, 10, 9, 8, 7, 8, 3, 10, 8, 8, 8], [10, 10, 8, 10, 10, 10, 2, 5, 10, 10, 10, 9, 7, 9, 3, 3, 3], [4, 4, 5, 3, 2, 2, 9, 8, 4, 2, 2, 3, 4, 4, 5, 5, 5], [4, 4, 4, 7, 5, 6, 4, 4, 4, 5, 6, 7, 10, 2, 8, 8, 8], [7, 8, 6, 6, 8, 8, 7, 9, 8, 7, 8, 7, 9, 8, 6, 6, 6], [8, 7, 7, 7, 7, 7, 8, 9, 5, 7, 7, 7, 5, 7, 1, 1, 1], [1, 2, 3, 1, 1, 1, 9, 7, 7, 1, 1, 1, 9, 3, 4, 4, 4], [2, 5, 6, 1, 1, 2, 7, 5, 6, 1, 2, 2, 8, 4, 1, 1, 1], [10, 10, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 3, 10, 7, 7, 7], [6, 3, 4, 9, 9, 9, 8, 7, 5, 9, 9, 10, 1, 2, 10, 10, 10], [9, 10, 10, 9, 9, 9, 1, 8, 10, 9, 9, 9, 8, 4, 9, 9, 9]])
columns = [f"feature{i}" for i in range(1, 18)]
indexes = ['AAPL', 'AMGN', 'AXP', 'BA', 'CAT', 'CRM', 'CSCO', 'CVX', 'DIS', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'JPM', 'KO', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'V', 'VZ', 'WBA', 'WMT']
df = pd.DataFrame(arr, columns=columns, index=indexes)
df.index.name = 'Ticker'
df_long = df.reset_index().melt(id_vars='Ticker', var_name='Feature', value_name='Value')
sns.set_style('darkgrid')
g = sns.relplot(data=df_long, x="Feature", y="Ticker", size="Value", hue="Value",
marker="s", sizes=(20, 200), palette="blend:limegreen,orange", height=8, aspect=1.1)
g.ax.tick_params(axis='x', labelrotation=45)
g.ax.set_facecolor('aliceblue')
g.ax.grid(color='red', lw=1)
g.fig.subplots_adjust(left=0.1, bottom=0.15)
plt.show()
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.