[英]Matplotlib: Plot Data and then Time Series Predictions
我正在使用 matplotlib 来显示股票价格随时间的变化。 我想关注过去 90 天,然后预测接下来的 14 天。 我有过去 90 天的数据和我的预测,但我想用不同的颜色绘制我的预测,所以很明显它们是不同的。
我该怎么做?
如果我只是在我的代码中添加第二个plot()
调用,预测将从与我的 90 天数据相同的点开始并叠加,这不是我想要的。
现在我正在这样做:
df[-90:]["price"].plot()
plt.show()
谢谢!
希望这是你想要的:
import pandas as pd
import numpy as np; np.random.seed(1)
import matplotlib.pyplot as plt
datelist = pd.date_range(pd.datetime(2018, 1, 1), periods=104)
df = pd.DataFrame(np.cumsum(np.random.randn(104)),
columns=['price'], index=datelist)
plt.plot(df[:90].index, df[:90].values)
plt.plot(df[90:].index, df[90:].values)
# If you don't like the break in the graph, change 90 to 89 in the above line
plt.gcf().autofmt_xdate()
plt.show()
简短的回答:
使用pd.merge()
并充分利用两个不同系列中的缺失值,得到两条不同颜色的线条。 对于您使用的数据帧索引类型(日期、整数或字符串),此建议将非常灵活。 这是你会得到的:
长答案:
关于...的详细信息
我想关注过去 90 天,然后预测接下来的 14 天。
...我将假设您使用的是带有每日索引的数据框。 我还假设您知道 90 天的数据集和 14 天的数据集的索引值。
这是一个包含 104 个观察值(随机数据)的数据框:
片段 1:
import pandas as pd
import numpy as np
np.random.seed(12)
rows = 104
df = pd.DataFrame(np.random.randint(-4,5,size=(rows, 1)), columns=['data'])
datelist = pd.date_range(pd.datetime(2018, 1, 1).strftime('%Y-%m-%d'), periods=rows).tolist()
df['dates'] = datelist
df = df.set_index(['dates'])
df.index = pd.to_datetime(df.index)
df = df.cumsum()
df.plot()
情节 1:
为了复制您的设置,我已将数据帧拆分为两个不同的帧,其中包含 90 次观察(价格)和 14 天(预测)。 这样,您将拥有两个不同的数据集,但关联的索引将是连续的 - 我假设这是您的实际情况。
片段 2:
df_90 = df[:90].copy(deep = True)
df_14 = df[-14:].copy(deep = True)
df_90.columns = ['price']
df_14.columns = ['predictions']
df_90.plot()
df_14.plot()
情节 2:
现在您可以将它们合并在一起,这样您将获得一个包含两列(数据和预测)的数据框。 当然,您最终会得到一些丢失的数据,但这正是您在绘制时会为您提供两条颜色不同的线的原因。
片段 3:
df_all = pd.merge(df_90, df_14, how = 'outer', left_index=True, right_index=True)
df_all.plot()
情节 3:
我希望建议的解决方案符合您的实际情况。 让我知道有关索引的详细信息是否会成为问题,我也会考虑一下。
以下是简单复制粘贴的完整代码:
import pandas as pd
import numpy as np
np.random.seed(12)
rows = 104
df = pd.DataFrame(np.random.randint(-4,5,size=(rows, 1)), columns=['data'])
datelist = pd.date_range(pd.datetime(2018, 1, 1).strftime('%Y-%m-%d'), periods=rows).tolist()
df['dates'] = datelist
df = df.set_index(['dates'])
df.index = pd.to_datetime(df.index)
df = df.cumsum()
df.plot()
df_90 = df[:90].copy(deep = True)
df_14 = df[-14:].copy(deep = True)
df_90.columns = ['price']
df_14.columns = ['predictions']
df_90.plot()
df_14.plot()
df_all = pd.merge(df_90, df_14, how = 'outer', left_index=True, right_index=True)
df_all.plot()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.