繁体   English   中英

如何使用 matplotlib/seaborn 在时间序列数据上获得正确的相关性 plot?

[英]how to get correct correlation plot on time series data with matplotlib/seaborn?

我有每周收集的时间序列数据,我想在其中查看其两列的相关性。 为此,我可以找到两列之间的相关性,并希望了解滚动相关性每年的变化情况。 我目前的方法工作正常,但我需要在进行滚动相关之前对两列进行规范化并制作一条线 plot。 在我目前的尝试中,我不知道如何显示 3 年、5 年的滚动相关性。 谁能建议在matplotlib中这样做的可能想法?

当前尝试

这是我目前的尝试:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

dataPath="https://gist.github.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa#file-sample_data-csv"

def ts_corr_plot(dataPath, roll_window=4):
    df = pd.read_csv(dataPath)
    df['Date'] = pd.to_datetime(df['Date'])
    df['week'] = pd.DatetimeIndex(df['date']).week
    df['year'] = pd.DatetimeIndex(df['date']).year
    df['week'] = df['date'].dt.strftime('%W').astype('uint8')
    
    def find_corr(x):
        df = df.loc[x.index]
        return df[:, 1].corr(df[:, 2])
    
    df['corr'] = df['week'].rolling(roll_window).apply(find_corr)
    fig, ax = plt.subplots(figsize=(7, 4), dpi=144)
    sns.lineplot(x='week', y='corr', hue='year', data=df,alpha=.8)
    plt.show()
    plt.close

更新

我想查看不同时间 window 的滚动相关性,例如:

plt_1 = ts_corr_plot(dataPath, roll_window=4)
plt_2 = ts_corr_plot(dataPath, roll_window=12)
plt_3 = ts_corr_plot(dataPath, roll_window=24)

我需要在图中添加 3 年、5 年的滚动相关性,但我找不到更好的方法。 谁能指出我如何为时间序列数据制作滚动相关线 plot? 如何改进当前的尝试? 任何想法?

所需 plot

这是我想要获得的预期 plot :

在此处输入图像描述

在esaborn中自定义图例是很辛苦的,所以我在matplotlib中创建了代码。

  1. 修正了相关系数的计算方法。 您的代码给了我一个错误,所以如果我错了,请纠正我。
  2. 折线图的颜色似乎是从想要的图形颜色来的画面颜色,所以我使用了matplotlib中定义的画面的10 colors。
  3. 为了计算 3 年的相关系数,我使用了 156 个线单位,即 3 年的每周数据。 如有错误请更正此逻辑。
  4. 我分别在循环过程中创建 4 周和 3 年的图表。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

dataPath="https://gist.githubusercontent.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa/raw/414a2fc2988fcf0b8e6911d77cccfbeb4b9e9664/sample_data.csv"

df = pd.read_csv(dataPath)
df['Date'] = pd.to_datetime(df['Date'])
df['week'] = df['Date'].dt.isocalendar().week
df['year'] = df['Date'].dt.year
df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

def find_corr(x):
    dfc = df.loc[x.index]
    tmp = dfc.iloc[:, [1,2]].corr()
    tmp = tmp.iloc[0,1]
    return tmp

roll_window=4
df['corr'] = df['week'].rolling(roll_window).apply(find_corr)
df3 = df.copy() # three year
df3['corr3'] = df3['year'].rolling(156).apply(find_corr) # 3 year = 52 week x 3 year = 156 

fig, ax = plt.subplots(figsize=(12, 4), dpi=144)
cmap = plt.get_cmap("tab10")

for i,y in enumerate(df['year'].unique()):
    tmp = df[df['year'] == y]
    ax.plot(tmp['week'], tmp['corr'], color=cmap(i), label=y)

for i,y in enumerate(df['year'].unique()):
    tmp = df3[df3['year'] == y]
    if tmp['corr3'].notnull().all():
        ax.plot(tmp['week'], tmp['corr3'], color=cmap(i), lw=3, linestyle='--', label=str(y)+' 3 year avg')

ax.grid(axis='both')
ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0), borderaxespad=1)
plt.show()
# plt.close

在此处输入图像描述

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM