簡體   English   中英

是否可以調整 Seaborn 熱圖中的正方形(單元格)的大小?

[英]Is it possible to adjust the size of squares (cells) in Seaborn heatmap?

假設我有一個像這樣的熱圖 plot:

在此處輸入圖像描述

使用這些數據:

import numpy as np
import pandas as pd


# Sample data for the heatmap: 29 rows x 17 columns of integer values.
arr = np.array([
    [2, 2, 2, 8, 7, 7, 6, 5, 2, 7, 7, 8, 7, 5, 6, 6, 6],
    [8, 7, 5, 4, 4, 3, 9, 6, 7, 4, 3, 2, 8, 9, 3, 3, 3],
    [1, 3, 2, 2, 2, 3, 5, 3, 3, 2, 3, 3, 4, 1, 10, 10, 10],
    [3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 9, 9, 9],
    [7, 6, 7, 6, 6, 6, 2, 2, 5, 6, 5, 4, 7, 9, 9, 9, 9],
    [6, 7, 8, 4, 3, 4, 4, 8, 7, 3, 4, 5, 6, 3, 4, 4, 4],
    [3, 1, 1, 9, 9, 9, 3, 1, 8, 9, 9, 9, 1, 6, 1, 1, 1],
    [3, 3, 3, 5, 5, 5, 5, 1, 2, 5, 6, 5, 10, 8, 8, 8, 8],
    [1, 1, 1, 2, 3, 2, 7, 3, 1, 3, 2, 2, 10, 8, 7, 7, 7],
    [5, 5, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 5, 2, 7, 7, 7],
    [7, 9, 10, 3, 4, 4, 8, 9, 9, 3, 4, 6, 2, 3, 2, 2, 2],
    [5, 6, 7, 3, 3, 3, 3, 1, 4, 4, 3, 4, 9, 10, 2, 2, 2],
    [4, 4, 3, 4, 4, 4, 3, 4, 3, 4, 4, 3, 2, 7, 10, 10, 10],
    [2, 1, 1, 8, 8, 8, 1, 4, 2, 8, 8, 8, 4, 1, 5, 5, 5],
    [9, 9, 8, 8, 8, 8, 5, 6, 8, 8, 8, 5, 1, 5, 2, 2, 2],
    [5, 5, 5, 5, 5, 5, 4, 2, 1, 5, 5, 4, 6, 5, 5, 5, 5],
    [8, 8, 9, 10, 10, 10, 6, 7, 6, 10, 10, 10, 3, 7, 4, 4, 4],
    [9, 8, 10, 5, 7, 7, 10, 10, 9, 6, 5, 6, 5, 6, 3, 3, 3],
    [10, 9, 9, 7, 6, 5, 10, 10, 9, 8, 7, 8, 3, 10, 8, 8, 8],
    [10, 10, 8, 10, 10, 10, 2, 5, 10, 10, 10, 9, 7, 9, 3, 3, 3],
    [4, 4, 5, 3, 2, 2, 9, 8, 4, 2, 2, 3, 4, 4, 5, 5, 5],
    [4, 4, 4, 7, 5, 6, 4, 4, 4, 5, 6, 7, 10, 2, 8, 8, 8],
    [7, 8, 6, 6, 8, 8, 7, 9, 8, 7, 8, 7, 9, 8, 6, 6, 6],
    [8, 7, 7, 7, 7, 7, 8, 9, 5, 7, 7, 7, 5, 7, 1, 1, 1],
    [1, 2, 3, 1, 1, 1, 9, 7, 7, 1, 1, 1, 9, 3, 4, 4, 4],
    [2, 5, 6, 1, 1, 2, 7, 5, 6, 1, 2, 2, 8, 4, 1, 1, 1],
    [10, 10, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 3, 10, 7, 7, 7],
    [6, 3, 4, 9, 9, 9, 8, 7, 5, 9, 9, 10, 1, 2, 10, 10, 10],
    [9, 10, 10, 9, 9, 9, 1, 8, 10, 9, 9, 9, 8, 4, 9, 9, 9],
])

# Column labels "feature1" .. "feature17", one per column of `arr`.
columns = [f"feature{n}" for n in range(1, 18)]

# Row labels, one per row of `arr`.
indexes = [
    'AAPL', 'AMGN', 'AXP', 'BA', 'CAT', 'CRM', 'CSCO', 'CVX', 'DIS', 'GS',
    'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'JPM', 'KO', 'MCD', 'MMM', 'MRK',
    'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'V', 'VZ', 'WBA', 'WMT',
]

# Wide-format frame: rows labelled by `indexes`, columns by `columns`.
df = pd.DataFrame(data=arr, index=indexes, columns=columns)

使用此代碼:

import seaborn as sns
import matplotlib.pyplot as plt

# Create a 10x10-inch figure at 600 dpi holding a single Axes.
fig = plt.figure(figsize=(10, 10), dpi=600)
ax = fig.subplots()
# Annotated heatmap of `df` with square cells on a diverging red/blue colormap.
a = sns.heatmap(data=df, annot=True, square=True, cmap="RdBu_r", ax=ax)
plt.show()

我想根據它的值調整每個單元格的大小,我的意思是,值為 1 的方形單元格應該小於具有更高值的單元格!
例子: 在此處輸入圖像描述

請注意,此示例與之前的熱圖 plot 的值沒有嚴格相關。 我只是提供了一個示例來說明我的意思,即根據其值調整每個方形單元格的大小。

這是您可以使用 scatterplot 或 relplot 完成的事情:

# Load seaborn's "flights" example dataset.
flights = sns.load_dataset("flights")

# One square marker per (year, month) pair; both marker size and colour
# encode the "passengers" column.
g = sns.relplot(
    data=flights,
    x="year",
    y="month",
    hue="passengers",
    size="passengers",
    sizes=(40, 400),
    marker="s",
    palette="blend:b,r",
)

在此處輸入圖像描述

(這篇文章詳細闡述了@mwaskom 的優秀解決方案,適用於給定的 dataframe。)

對於大多數 seaborn 功能,將 dataframe 設置為“長格式”會有所幫助。

這是一個示例,說明如何將 dataframe 轉換為長格式,也就是 sns.relplot 或 sns.scatterplot 等函數所使用的格式。 或許,直接從用於創建 pivot 表的原始 dataframe 開始會更容易。

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# Sample data: 29 rows x 17 columns of integer values.
arr = np.array([
    [2, 2, 2, 8, 7, 7, 6, 5, 2, 7, 7, 8, 7, 5, 6, 6, 6],
    [8, 7, 5, 4, 4, 3, 9, 6, 7, 4, 3, 2, 8, 9, 3, 3, 3],
    [1, 3, 2, 2, 2, 3, 5, 3, 3, 2, 3, 3, 4, 1, 10, 10, 10],
    [3, 2, 4, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 9, 9, 9],
    [7, 6, 7, 6, 6, 6, 2, 2, 5, 6, 5, 4, 7, 9, 9, 9, 9],
    [6, 7, 8, 4, 3, 4, 4, 8, 7, 3, 4, 5, 6, 3, 4, 4, 4],
    [3, 1, 1, 9, 9, 9, 3, 1, 8, 9, 9, 9, 1, 6, 1, 1, 1],
    [3, 3, 3, 5, 5, 5, 5, 1, 2, 5, 6, 5, 10, 8, 8, 8, 8],
    [1, 1, 1, 2, 3, 2, 7, 3, 1, 3, 2, 2, 10, 8, 7, 7, 7],
    [5, 5, 2, 2, 2, 1, 1, 3, 3, 2, 1, 1, 5, 2, 7, 7, 7],
    [7, 9, 10, 3, 4, 4, 8, 9, 9, 3, 4, 6, 2, 3, 2, 2, 2],
    [5, 6, 7, 3, 3, 3, 3, 1, 4, 4, 3, 4, 9, 10, 2, 2, 2],
    [4, 4, 3, 4, 4, 4, 3, 4, 3, 4, 4, 3, 2, 7, 10, 10, 10],
    [2, 1, 1, 8, 8, 8, 1, 4, 2, 8, 8, 8, 4, 1, 5, 5, 5],
    [9, 9, 8, 8, 8, 8, 5, 6, 8, 8, 8, 5, 1, 5, 2, 2, 2],
    [5, 5, 5, 5, 5, 5, 4, 2, 1, 5, 5, 4, 6, 5, 5, 5, 5],
    [8, 8, 9, 10, 10, 10, 6, 7, 6, 10, 10, 10, 3, 7, 4, 4, 4],
    [9, 8, 10, 5, 7, 7, 10, 10, 9, 6, 5, 6, 5, 6, 3, 3, 3],
    [10, 9, 9, 7, 6, 5, 10, 10, 9, 8, 7, 8, 3, 10, 8, 8, 8],
    [10, 10, 8, 10, 10, 10, 2, 5, 10, 10, 10, 9, 7, 9, 3, 3, 3],
    [4, 4, 5, 3, 2, 2, 9, 8, 4, 2, 2, 3, 4, 4, 5, 5, 5],
    [4, 4, 4, 7, 5, 6, 4, 4, 4, 5, 6, 7, 10, 2, 8, 8, 8],
    [7, 8, 6, 6, 8, 8, 7, 9, 8, 7, 8, 7, 9, 8, 6, 6, 6],
    [8, 7, 7, 7, 7, 7, 8, 9, 5, 7, 7, 7, 5, 7, 1, 1, 1],
    [1, 2, 3, 1, 1, 1, 9, 7, 7, 1, 1, 1, 9, 3, 4, 4, 4],
    [2, 5, 6, 1, 1, 2, 7, 5, 6, 1, 2, 2, 8, 4, 1, 1, 1],
    [10, 10, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 3, 10, 7, 7, 7],
    [6, 3, 4, 9, 9, 9, 8, 7, 5, 9, 9, 10, 1, 2, 10, 10, 10],
    [9, 10, 10, 9, 9, 9, 1, 8, 10, 9, 9, 9, 8, 4, 9, 9, 9],
])

# Column labels "feature1" .. "feature17".
columns = list(map("feature{}".format, range(1, 18)))

# Row labels, one per row of `arr`.
indexes = [
    'AAPL', 'AMGN', 'AXP', 'BA', 'CAT', 'CRM', 'CSCO', 'CVX', 'DIS', 'GS',
    'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'JPM', 'KO', 'MCD', 'MMM', 'MRK',
    'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'V', 'VZ', 'WBA', 'WMT',
]

# Wide-format frame; the index is named 'Ticker' so that reset_index()
# yields a 'Ticker' column to melt on.
df = pd.DataFrame(arr, columns=columns, index=pd.Index(indexes, name='Ticker'))

# Long format: one row per (Ticker, Feature) pair with its Value.
df_long = pd.melt(
    df.reset_index(),
    id_vars='Ticker',
    var_name='Feature',
    value_name='Value',
)
# Use seaborn's "darkgrid" theme for the figure created below.
sns.set_style('darkgrid')

# One square marker per (Feature, Ticker) pair; both marker size and colour
# encode the "Value" column of the long-format frame.
g = sns.relplot(data=df_long, x="Feature", y="Ticker", size="Value", hue="Value",
                marker="s", sizes=(20, 200), palette="blend:limegreen,orange", height=8, aspect=1.1)
g.ax.tick_params(axis='x', labelrotation=45)  # rotate x tick labels by 45 degrees
g.ax.set_facecolor('aliceblue')
g.ax.grid(color='red', lw=1)

# `FacetGrid.fig` is the deprecated accessor; `FacetGrid.figure` is the
# supported name for the underlying matplotlib Figure.
g.figure.subplots_adjust(left=0.1, bottom=0.15)
plt.show()

sns.relplot 使用長格式數據框

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM