簡體   English   中英

Pandas dataframe 散點圖 plot 以 2 級多索引為軸

[英]Pandas dataframe scatter plot with 2-level Multiindex as axes

我有一個帶有 2 級多索引的 dataframe df 我想要一個散點圖 plot,x 軸為 0 級,y 軸為 1 級,所有滿足條件的組合的散點,比如在特定列'col'中具有非零值。

import matplotlib.pyplot as plt
from itertools import product
import numpy as np

lengths = [3, 2]
df_index = pd.MultiIndex.from_product([list(product([-1,1], repeat=li)) for li in lengths], names=['level1', 'level2'])

df_cols = ['cols']
df = pd.DataFrame([[0.] * len(df_cols)] * len(df_index), index=df_index, columns=df_cols)
df['cols'] = np.random.randint(0, 2, size = len(df))
df

產生以下形式的 dataframe

                       cols
level1       level2        
(-1, -1, -1) (-1, -1)     0
             (-1, 1)      0
             (1, -1)      0
             (1, 1)       0
(-1, -1, 1)  (-1, -1)     1
             (-1, 1)      0
             (1, -1)      1
             (1, 1)       1
(-1, 1, -1)  (-1, -1)     0
             (-1, 1)      0
             (1, -1)      0
             (1, 1)       0
(-1, 1, 1)   (-1, -1)     0
             (-1, 1)      0
             (1, -1)      1
             (1, 1)       0
(1, -1, -1)  (-1, -1)     0
             (-1, 1)      0
             (1, -1)      1
             (1, 1)       1
(1, -1, 1)   (-1, -1)     0
             (-1, 1)      1
             (1, -1)      1
             (1, 1)       0

...

現在,我想要一個散點圖 plot,x 軸上的 level1 索引和 y 軸上的 level2 索引,這樣對於 cols(x,y).= 0 的每個 (x,y) 都有一個點。

讓我們首先創建一個具有 2 級多索引的示例 dataframe:

import pandas as pd
import numpy as np
iterables = [[1, 2, 3, 4], [0,1, 2, 3, 4,5]]
my_multiindex=pd.MultiIndex.from_product(iterables, names=['first', 'second'])
series1 = pd.Series(np.random.randn(24), index=my_multiindex)
series2 = pd.Series(np.random.randn(24), index=my_multiindex)
df=pd.DataFrame({'col1':series1,'col2':series2})

現在,讓我們獲取滿足給定條件的索引值:

index_values=df[df.col1<0].index.values

然后我們分開xy坐標:

xs=[a[0] for a in index_values]
ys=[a[1] for a in index_values]

然后我們 plot:

from matplotlib import pyplot as plt
plt.scatter(xs,ys)

如果您希望散點的大小反映實際值,您可以使用:

column_values=abs(df[df.col1<0].col1.values)
plt.scatter(xs,ys,s=column_values*10)

編輯以反映已編輯的問題

您只需要將xsys轉換為字符串。 我還使用了一個大圖,以便軸刻度標簽不重疊:

plt.figure(figsize=(10,10))
plt.scatter([str(a) for a in xs],[str(a) for a in ys])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM