[英]How to keep the ones in the Diagonal of sns.heatmap?
我想 plot 與 sns.heatmap 的相關矩陣並有一些問題。 這是我的代碼:
plt.figure(figsize=(8,8)) mask =np.zeros_like(data.corr()) mask[np.triu_indices_from(mask)] = True sns.heatmap(data.corr(), mask=mask, linewidth=1, annot=True, fmt=".2f",cmap='coolwarm',vmin=-1, vmax=1) plt.show()
這就是我得到的:[相關矩陣][1] [1]: https://i.stack.imgur.com/DX2oN.png \
現在我有一些問題:
1)我怎樣才能把那些放在對角線上?
2)如何改變x軸的position?
3) 我希望顏色條從 1 變為 -1,但代碼不工作
我希望有人能幫幫忙。
謝謝
我認為您必須檢查data.corr()
,因為您的代碼正確並且可以診斷(請參閱下文)。 一個問題是:您使用np.triu
但顯示的圖片顯示np.tirl
。
這是我測試過的代碼-對角線在這里:
N = 5
A = np.arange(N*N).reshape(N,N)
B = np.tril(A)
mask =np.zeros_like(A)
mask[np.triu_indices_from(mask)] = True
print('A'); print(A); print()
print('tril(A)'); print(B); print()
print('mask'); print(mask); print()
給
A
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]]
tril(A)
[[ 0 0 0 0 0]
[ 5 6 0 0 0]
[10 11 12 0 0]
[15 16 17 18 0]
[20 21 22 23 24]]
mask
[[1 1 1 1 1]
[0 1 1 1 1]
[0 0 1 1 1]
[0 0 0 1 1]
[0 0 0 0 1]]
編輯:補充
您可以重新調整面膜,例如
C = A *mask
D = np.where(C > 1, 1,C)
print('D'); print(D)
給
D
[[0 1 1 1 1]
[0 1 1 1 1]
[0 0 1 1 1]
[0 0 0 1 1]
[0 0 0 0 1]]
D的對角線的第一個元素現在為零,因為A的對角線的第一個元素也為零。
編輯:補充2
F = np.tril(A,-1)
E = np.eye(N)
G = E + F
print('F'); print(F); print()
print('E'); print(E); print()
print('G'); print(G); print()
給
F
[[ 0 0 0 0 0]
[ 5 0 0 0 0]
[10 11 0 0 0]
[15 16 17 0 0]
[20 21 22 23 0]]
E
[[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 1.]]
G
[[ 1. 0. 0. 0. 0.]
[ 5. 1. 0. 0. 0.]
[10. 11. 1. 0. 0.]
[15. 16. 17. 1. 0.]
[20. 21. 22. 23. 1.]]
mask[np.triu_indices_from(mask)]
將定義三角形(包括對角線)
mask[np.eye(mask.shape[0], dtype=bool)]
將定義對角線。
如果將它們放在一起,則可以獨立控制它們。 (請注意,您需要在對角線之前設置三角形)。
def plot_correlation_matrix(df, remove_diagonal=True, remove_triangle=False, **kwargs):
corr = df.corr()
# Apply mask
mask = np.zeros_like(corr, dtype=np.bool)
mask[np.triu_indices_from(mask)] = remove_triangle
mask[np.eye(mask.shape[0], dtype=bool)] = remove_diagonal
# Plot
# plt.figure(figsize=(8,8))
sns.heatmap(corr, mask=mask, **kwargs)
plt.show()
所以此命令將生成矩陣,移除上三角,但保留對角線:
plot_correlation_matrix(df[colunas_notas], remove_diagonal=False, remove_triangle=True)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.