How to get a correct correlation plot on time series data with matplotlib/seaborn?
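I have time series data collected weekly, and I want to look at the correlation between two of its columns. I can compute the correlation between the two columns, and I would like to see how the rolling correlation changes from year to year. My current approach works, but I need to normalize the two columns before computing the rolling correlation and then draw a line plot. In my current attempt, I don't know how to show the 3-year and 5-year rolling correlations. Can anyone suggest a possible way of doing this in matplotlib?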
Current attempt:
Here is my current attempt:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

dataPath = "https://gist.github.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa#file-sample_data-csv"

def ts_corr_plot(dataPath, roll_window=4):
    df = pd.read_csv(dataPath)
    df['Date'] = pd.to_datetime(df['Date'])
    df['year'] = df['Date'].dt.year
    # week of the year (Monday as the first day of the week)
    df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

    def find_corr(x):
        # x is one rolling window; its index selects the matching rows of df
        dfc = df.loc[x.index]
        # correlation between the second and third columns
        return dfc.iloc[:, 1].corr(dfc.iloc[:, 2])

    df['corr'] = df['week'].rolling(roll_window).apply(find_corr)
    fig, ax = plt.subplots(figsize=(7, 4), dpi=144)
    sns.lineplot(x='week', y='corr', hue='year', data=df, alpha=.8)
    plt.show()
    plt.close()
Update:
I would like to see the rolling correlation for different time windows, for example:
plt_1 = ts_corr_plot(dataPath, roll_window=4)
plt_2 = ts_corr_plot(dataPath, roll_window=12)
plt_3 = ts_corr_plot(dataPath, roll_window=24)
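I need to add the 3-year and 5-year rolling correlations to the plot, but I couldn't find a good way to do it. Can anyone point out how to make a rolling correlation line plot for time series data? How can I improve the current attempt? Any ideas?
Desired plot
Here is the expected plot I would like to get: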
Customizing the legend in seaborn is cumbersome, so I wrote the code in matplotlib.
import pandas as pd
import matplotlib.pyplot as plt

dataPath = "https://gist.githubusercontent.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa/raw/414a2fc2988fcf0b8e6911d77cccfbeb4b9e9664/sample_data.csv"
df = pd.read_csv(dataPath)
df['Date'] = pd.to_datetime(df['Date'])
df['year'] = df['Date'].dt.year
# week of the year (Monday as the first day of the week)
df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

def find_corr(x):
    # x is one rolling window; its index selects the matching rows of df
    dfc = df.loc[x.index]
    # correlation between the second and third columns
    tmp = dfc.iloc[:, [1, 2]].corr()
    return tmp.iloc[0, 1]

roll_window = 4
df['corr'] = df['week'].rolling(roll_window).apply(find_corr)

df3 = df.copy()  # three year
df3['corr3'] = df3['year'].rolling(156).apply(find_corr)  # 3 years = 52 weeks x 3 = 156

fig, ax = plt.subplots(figsize=(12, 4), dpi=144)
cmap = plt.get_cmap("tab10")

# short-window rolling correlation, one line per year
for i, y in enumerate(df['year'].unique()):
    tmp = df[df['year'] == y]
    ax.plot(tmp['week'], tmp['corr'], color=cmap(i), label=y)

# 3-year rolling correlation, drawn only for years with no missing values
for i, y in enumerate(df['year'].unique()):
    tmp = df3[df3['year'] == y]
    if tmp['corr3'].notnull().all():
        ax.plot(tmp['week'], tmp['corr3'], color=cmap(i), lw=3, linestyle='--', label=str(y) + ' 3 year avg')

ax.grid(axis='both')
ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0), borderaxespad=1)
plt.show()
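As a possible variation on the code above, pandas' built-in `Series.rolling(n).corr(other)` can compute the rolling correlation directly, which makes it easy to add a 5-year window (260 weekly rows) next to the 3-year one. The sketch below is not part of the original answer. It uses synthetic weekly data in place of the CSV, and the names `rolling_corr_frame`, `plot_rolling_corr`, `col_a`, and `col_b` are hypothetical. It also illustrates that z-score normalizing each column over the whole series should not change the Pearson correlation, so normalization matters for plotting the raw columns rather than for the correlation itself.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def rolling_corr_frame(df, col_a, col_b, windows):
    """Return one rolling-correlation column per window length (in rows)."""
    out = pd.DataFrame(index=df.index)
    for name, n in windows.items():
        # Pearson correlation over the trailing n rows; the first n-1 rows are NaN
        out[name] = df[col_a].rolling(n).corr(df[col_b])
    return out


def plot_rolling_corr(corr):
    """Draw one line per window on a shared axis and return (fig, ax)."""
    fig, ax = plt.subplots(figsize=(12, 4))
    for name in corr.columns:
        ax.plot(corr.index, corr[name], label=name)
    ax.set_ylabel("rolling correlation")
    ax.grid(axis="both")
    ax.legend(loc="upper left", bbox_to_anchor=(1.0, 1.0))
    return fig, ax


# Synthetic weekly data standing in for the real CSV (assumption: two numeric columns)
rng = np.random.default_rng(0)
dates = pd.date_range("2015-01-04", periods=6 * 52, freq="W")
a = rng.normal(size=len(dates)).cumsum()
b = 0.5 * a + rng.normal(size=len(dates))
demo = pd.DataFrame({"col_a": a, "col_b": b}, index=dates)

# Weekly data has about 52 rows per year: 3 years = 156 rows, 5 years = 260 rows
windows = {"4 weeks": 4, "3 years": 156, "5 years": 260}
corr = rolling_corr_frame(demo, "col_a", "col_b", windows)

# z-scoring rescales each column but leaves the Pearson correlation unchanged
zscored = (demo - demo.mean()) / demo.std()
corr_z = rolling_corr_frame(zscored, "col_a", "col_b", windows)
```

Calling `plot_rolling_corr(corr)` followed by `plt.show()` draws the three windows against the date axis. With the real data, `demo` would be replaced by the loaded frame indexed by `Date`, and the two column names by the actual second and third columns.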