How to get a correct correlation plot on time series data with matplotlib/seaborn?
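I have time series data collected weekly, and I want to look at the correlation between two of its columns. I can compute the correlation between the two columns, and I would like to see how the rolling correlation changes from year to year. My current approach works, but I need to normalize the two columns before computing the rolling correlation and then draw a line plot. In my current attempt I do not know how to show the 3-year and 5-year rolling correlations. Can anyone suggest a way to do this in matplotlib?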
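On the normalization step: a minimal sketch, assuming "normalize" means z-score standardization and using synthetic weekly data with placeholder column names `a` and `b` (neither comes from the post). It suggests that standardizing first should not change the result, because Pearson correlation is unaffected by shifting and positive rescaling.

```python
import numpy as np
import pandas as pd

# Synthetic weekly data; the column names "a" and "b" are placeholders.
rng = np.random.default_rng(0)
dates = pd.date_range("2015-01-04", periods=300, freq="W")
a = rng.normal(size=300).cumsum()
demo = pd.DataFrame({"a": a, "b": 0.5 * a + rng.normal(size=300)}, index=dates)

def zscore(s):
    """Standardize a series to zero mean and unit standard deviation."""
    return (s - s.mean()) / s.std()

window = 12  # window length in weeks (an arbitrary choice for this demo)

# Rolling Pearson correlation on the raw columns and on the standardized ones.
raw_corr = demo["a"].rolling(window).corr(demo["b"])
norm_corr = zscore(demo["a"]).rolling(window).corr(zscore(demo["b"]))

# Compare the two results, ignoring the leading NaN values of the first window.
same = np.allclose(raw_corr.dropna(), norm_corr.dropna())
print(same)
```

Normalization still matters when the two series themselves are drawn on a shared axis, but for the correlation line it can probably be skipped.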
Current attempt:
Here is my current attempt:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Gist page URL; pd.read_csv needs the raw CSV URL to parse the file
dataPath = "https://gist.github.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa#file-sample_data-csv"

def ts_corr_plot(dataPath, roll_window=4):
    df = pd.read_csv(dataPath)
    df['Date'] = pd.to_datetime(df['Date'])
    df['year'] = df['Date'].dt.year
    # week of the year (00-53), with Monday as the first day of the week
    df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

    def find_corr(x):
        # rows of the current window; columns 1 and 2 hold the two series
        dfc = df.loc[x.index]
        return dfc.iloc[:, 1].corr(dfc.iloc[:, 2])

    df['corr'] = df['week'].rolling(roll_window).apply(find_corr)

    fig, ax = plt.subplots(figsize=(7, 4), dpi=144)
    sns.lineplot(x='week', y='corr', hue='year', data=df, alpha=.8, ax=ax)
    plt.show()
    plt.close(fig)
Update:
I would like to see the rolling correlation for different time windows, for example:
plt_1 = ts_corr_plot(dataPath, roll_window=4)
plt_2 = ts_corr_plot(dataPath, roll_window=12)
plt_3 = ts_corr_plot(dataPath, roll_window=24)
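I need to add the 3-year and 5-year rolling correlations to the plot, but I could not find a good way to do it. Can anyone point out how to make a rolling correlation line plot for time series data? How can I improve the current attempt? Any ideas?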
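For the multi-year windows, one possible sketch is to convert years into a number of weekly rows and let pandas compute each rolling correlation directly. The helper name `rolling_corr_table`, the constant `WEEKS_PER_YEAR`, the column names `x` and `y`, and the synthetic data are all assumptions made for illustration; treating a year as exactly 52 weeks is an approximation.

```python
import numpy as np
import pandas as pd

WEEKS_PER_YEAR = 52  # approximation: a calendar year has 52 or 53 weeks

def rolling_corr_table(df, col_a, col_b, years=(1, 3, 5)):
    """Return one rolling-correlation column per window length given in years."""
    out = pd.DataFrame(index=df.index)
    for n in years:
        window = n * WEEKS_PER_YEAR  # window length in weekly rows
        out[f"{n}y"] = df[col_a].rolling(window).corr(df[col_b])
    return out

# Synthetic weekly data standing in for the two real columns.
rng = np.random.default_rng(1)
dates = pd.date_range("2012-01-01", periods=400, freq="W")
x = rng.normal(size=400)
weekly = pd.DataFrame({"x": x, "y": 0.6 * x + rng.normal(size=400)}, index=dates)

table = rolling_corr_table(weekly, "x", "y")
print(table.dropna().tail())
```

Each column stays NaN until its window is full, so a 5-year line only starts after 260 weekly rows. That is worth keeping in mind when the data set covers only a few years.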
Desired plot
This is the plot I expect to get:
Customizing the legend in seaborn is laborious, so I wrote the code in matplotlib.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

dataPath = "https://gist.githubusercontent.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa/raw/414a2fc2988fcf0b8e6911d77cccfbeb4b9e9664/sample_data.csv"

df = pd.read_csv(dataPath)
df['Date'] = pd.to_datetime(df['Date'])
df['year'] = df['Date'].dt.year
# week of the year (00-53), with Monday as the first day of the week
df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

def find_corr(x):
    # x is one rolling window; look up its rows and correlate columns 1 and 2
    dfc = df.loc[x.index]
    tmp = dfc.iloc[:, [1, 2]].corr()
    return tmp.iloc[0, 1]

roll_window = 4
df['corr'] = df['week'].rolling(roll_window).apply(find_corr)

df3 = df.copy()  # three-year window
# 3 years = 52 weeks x 3 = 156 rows
df3['corr3'] = df3['year'].rolling(156).apply(find_corr)

fig, ax = plt.subplots(figsize=(12, 4), dpi=144)
cmap = plt.get_cmap("tab10")

# thin lines: short-window rolling correlation, one line per year
for i, y in enumerate(df['year'].unique()):
    tmp = df[df['year'] == y]
    ax.plot(tmp['week'], tmp['corr'], color=cmap(i), label=y)

# thick dashed lines: 3-year rolling correlation; years that still contain NaN are skipped
for i, y in enumerate(df['year'].unique()):
    tmp = df3[df3['year'] == y]
    if tmp['corr3'].notnull().all():
        ax.plot(tmp['week'], tmp['corr3'], color=cmap(i), lw=3, linestyle='--', label=str(y) + ' 3 year avg')

ax.grid(axis='both')
# place the legend outside the axes, to the right of the plot
ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0), borderaxespad=1)
plt.show()
# plt.close
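As a variation on the code above, the sketch below draws several window lengths against a continuous date axis instead of week-of-year. This is a different layout from the one in the answer, not a replacement for it. The function name `plot_rolling_corr`, the column names `x` and `y`, the window labels, and the synthetic data are assumptions for illustration only.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_rolling_corr(df, col_a, col_b, windows, ax=None):
    """Draw one rolling-correlation line per (label, weeks) pair and return the Axes."""
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 4))
    for label, weeks in windows.items():
        corr = df[col_a].rolling(weeks).corr(df[col_b])
        # multi-year windows are drawn thicker so they stand out
        ax.plot(df.index, corr, label=label, lw=1 if weeks < 52 else 2.5)
    ax.set_ylabel("rolling correlation")
    ax.grid(axis="both")
    ax.legend(loc="upper left", bbox_to_anchor=(1.0, 1.0))
    return ax

# Synthetic weekly data standing in for the two real columns.
rng = np.random.default_rng(2)
dates = pd.date_range("2012-01-01", periods=400, freq="W")
x = rng.normal(size=400)
sample = pd.DataFrame({"x": x, "y": 0.6 * x + rng.normal(size=400)}, index=dates)

ax = plot_rolling_corr(sample, "x", "y",
                       {"12 weeks": 12, "3 years": 156, "5 years": 260})
```

Calling `plt.show()` afterwards displays the figure. Passing an existing `ax` lets the lines be added to a plot that already contains the per-year curves.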
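Notice: the technical posts on this site follow the CC BY-SA 4.0 license. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.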