繁体   English   中英

Pandas / Statsmodel OLS预测未来价值

[英]Pandas/Statsmodel OLS predicting future values

我一直试图在我创建的模型中预测未来的值。 我在熊猫和statsmodels中尝试了两种OLS。 这是我在statsmodels中的内容:

import statsmodels.api as sm
endog = pd.DataFrame(dframe['monthly_data_smoothed8'])
smresults = sm.OLS(dframe['monthly_data_smoothed8'], dframe['date_delta']).fit()
sm_pred = smresults.predict(endog)
sm_pred

返回的数组长度等于原始数据帧中的记录数,但值不相同。 当我使用pandas执行以下操作时,我没有返回任何值。

from pandas.stats.api import ols
res1 = ols(y=dframe['monthly_data_smoothed8'], x=dframe['date_delta'])
res1.predict

(请注意,Pandas中的OLS没有.fit函数)有人可以了解我如何从大熊猫或statsmodel中的OLS模型获得未来的预测 - 我意识到我一定不能正确使用.predict而且我已经阅读人们已经遇到的其他多个问题,但它们似乎并不适用于我的案例。

编辑我认为定义的'endog'是不正确的 - 我应该传递我想要预测的值; 因此,我创建了一个超过最后记录值的12个期间的日期范围。 但是我仍然错过了一些错误:

matrices are not aligned

编辑这里是一个数据片段,数字的最后一列(红色)是日期增量,它是与第一个日期相差几个月:

month   monthly_data    monthly_data_smoothed5  monthly_data_smoothed8  monthly_data_smoothed12 monthly_data_smoothed3  date_delta
0   2011-01-31  3.711838e+11    3.711838e+11    3.711838e+11    3.711838e+11    3.711838e+11    0.000000
1   2011-02-28  3.776706e+11    3.750759e+11    3.748327e+11    3.746975e+11    3.755084e+11    0.919937
2   2011-03-31  4.547079e+11    4.127964e+11    4.083554e+11    4.059256e+11    4.207653e+11    1.938438
3   2011-04-30  4.688370e+11    4.360748e+11    4.295531e+11    4.257843e+11    4.464035e+11    2.924085

我认为你的问题是statsmodels默认情况下不会添加拦截,所以你的模型没有太大的优势。 在你的代码中解决它将是这样的:

dframe = pd.read_clipboard() # your sample data
dframe['intercept'] = 1
X = dframe[['intercept', 'date_delta']]
y = dframe['monthly_data_smoothed8']

smresults = sm.OLS(y, X).fit()

dframe['pred'] = smresults.predict()

另外,对于它的价值,我认为在处理DataFrames时,statsmodel公式api更好用,并且默认情况下添加一个截距(添加一个- 1来删除)。 见下文,它应该给出相同的答案。

import statsmodels.formula.api as smf

smresults = smf.ols('monthly_data_smoothed8 ~ date_delta', dframe).fit()

dframe['pred'] = smresults.predict()

编辑:

要预测未来的值,只需将新数据传递给.predict()例如,使用第一个模型:

In [165]: smresults.predict(pd.DataFrame({'intercept': 1, 
                                          'date_delta': [0.5, 0.75, 1.0]}))
Out[165]: array([  2.03927604e+11,   2.95182280e+11,   3.86436955e+11])

在截距上 - 数字1没有任何编码只是基于OLS的数学(截距完全类似于总是等于1的回归量),因此您可以将值从摘要中拉出来。 查看statsmodels 文档 ,添加拦截的另一种方法是:

X = sm.add_constant(X)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM