Plotting confidence interval in SARIMAX prediction data

I am trying to plot a confidence interval band along the predicted values of a SARIMAX model.

The SARIMAX model is fitted like this:

model=sm.tsa.statespace.SARIMAX(data_df['Net Sales'],order=(1, 1, 1),seasonal_order=(1,1,1,12))
results=model.fit()
print(results.summary())

To plot the predicted values I am using the following code:

fig, ax = plt.subplots(figsize=(15,5))
ax.ticklabel_format(useOffset=False, style='plain')
data_df['Net_Sales forecast'] = results.predict(start = 48, end = 60, dynamic= True)  
data_df[['Net Sales', 'Net_Sales forecast']].plot(ax=ax, color=['blue', 'orange'], marker='o', legend=True)

Output: [plot image not included]

I want to plot a 95% confidence interval band around the forecast data. I have tried various ways, but to no avail.

I understand that I can access the confidence interval parameters in the SARIMAX results using the following:

ci = results.conf_int(alpha=0.05)
ci

Returns:

                     0             1
ar.L1    -3.633910e-01  1.108174e+00
ma.L1    -1.253388e+00  2.229091e-01
ar.S.L12 -3.360182e+00  4.001006e+00
ma.S.L12 -4.078321e+00  3.517885e+00
sigma2    3.080743e+13  3.080743e+13

How do I incorporate this into the plot to show the confidence interval band?

The confidence intervals you show are actually for the model parameters, not for the predictions. Here is an example of how you can compute and plot confidence intervals around the predictions, borrowing a dataset used in the statsmodels docs.

Note: You'll need to be cautious about interpreting these confidence intervals. Here is a relevant page discussing what is actually implemented in statsmodels.

import matplotlib.pyplot as plt
import pandas as pd
import statsmodels.api as sm
import requests
from io import BytesIO

# Get data
wpi1 = requests.get('https://www.stata-press.com/data/r12/wpi1.dta').content
data = pd.read_stata(BytesIO(wpi1))
data.index = data.t
# Set the frequency
data.index.freq='QS-OCT'

# Fit the model
model = sm.tsa.statespace.SARIMAX(data['wpi'], trend='c', order=(1,1,1))
results = model.fit(disp=False)

# Get predictions
# (can also utilize results.get_forecast(steps=n).summary_frame(alpha=0.05))
preds_df = (results
            .get_prediction(start='1991-01-01', end='1999-10-01')
            .summary_frame(alpha=0.05)
)
print(preds_df.head())
# wpi               mean   mean_se  mean_ci_lower  mean_ci_upper
# 1991-01-01  118.358860  0.725041     116.937806     119.779914
# 1991-04-01  120.340500  1.284361     117.823198     122.857802
# 1991-07-01  122.167206  1.865597     118.510703     125.823709
# 1991-10-01  123.858465  2.463735     119.029634     128.687296
# 1992-01-01  125.431312  3.070871     119.412517     131.450108

# Plot the training data, predicted means and confidence intervals
fig, ax = plt.subplots(figsize=(15,5))
data['wpi'].plot(ax=ax, label='Training Data')
ax.set(
    title='True and Predicted Values, with Confidence Intervals',
    xlabel='Date',
    ylabel='Actual / Predicted Values'
)
preds_df['mean'].plot(ax=ax, style='r', label='Predicted Mean')
ax.fill_between(
    preds_df.index, preds_df['mean_ci_lower'], preds_df['mean_ci_upper'],
    color='r', alpha=0.1
)
legend = ax.legend(loc='upper left')
plt.show()

