While searching for a tutorial on SVM, I found the code below online (Support Vector Machine _ Illustration), but it yields a weird chart. After debugging the code, I wonder whether the cause lies in the dates list, specifically in this line:
dates.append(int(row[0].split('-')[0]))
which always gives the same value on my side (i.e. 2016), or whether something else is wrong, although I can't see anything abnormal in the code.
EDIT
This deduction comes from the fact that
plt.scatter(dates, prices, color ='black', label ='Data');
plt.show()
yields a vertical line, whereas
dates.append(int(row[0].split('-')[0]))
is supposed, as described in the link and reflected in the code, to convert each YYYY-MM-DD date to a different integer value.
EDIT (2)
Substituting
dates.append(md.datestr2num(row[0]))
for
dates.append(int(row[0].split('-')[0]))
in the function get_data(filename) does help!
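A small sketch of why the substitution helps, using only the standard library. It assumes the date column holds YYYY-MM-DD strings, as stated above, and uses date.toordinal() as a stand-in for md.datestr2num (the two return different numbers, but both give one distinct value per day). The sample dates are made up for illustration.

```python
from datetime import datetime

# Hypothetical sample of the date column (YYYY-MM-DD, as in the question).
sample = ['2016-02-18', '2016-02-19', '2016-02-22']

# Original parsing: the first '-'-separated field is the year,
# so every row from the same year collapses onto one x value.
year_only = [int(s.split('-')[0]) for s in sample]

# Alternative: one distinct integer per calendar day.
ordinals = [datetime.strptime(s, '%Y-%m-%d').date().toordinal() for s in sample]

print(year_only)  # the same year three times
print(ordinals)   # three different integers
```

With year_only on the x axis every point shares the same x, which matches the vertical line in the scatter plot.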
import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt

dates = []
prices = []

def get_data(filename):
    # Fill the module-level dates and prices lists from the CSV file.
    with open(filename, 'r') as csvfile:
        csvFileReader = csv.reader(csvfile)
        next(csvFileReader)  # skip the header row
        for row in csvFileReader:
            dates.append(int(row[0].split('-')[0]))
            prices.append(float(row[6])) # from 1 i.e from Opening to closing price
    return

def predict_prices(dates,prices,x):
    # fit() expects a 2D array of shape (n_samples, n_features)
    dates = np.reshape(dates,(len(dates),1))
    svr_lin = SVR(kernel = 'linear', C = 1e3)
    svr_poly = SVR(kernel = 'poly', C = 1e3, degree = 2)
    svr_rbf = SVR(kernel = 'rbf', C = 1e3, gamma = 0.1)
    svr_lin.fit(dates,prices)
    svr_poly.fit(dates,prices)
    svr_rbf.fit(dates,prices)
    plt.scatter(dates, prices, color ='black', label ='Data')
    # plot each model with its own predictions
    plt.plot(dates, svr_rbf.predict(dates), color ='red', label = 'RBF model')
    plt.plot(dates, svr_lin.predict(dates), color ='green', label = 'Linear model')
    plt.plot(dates, svr_poly.predict(dates), color ='blue', label = 'Polynomial model')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.title('Support Vector Regression')
    plt.legend()
    plt.show()
    x = np.reshape(x, (1, 1))  # predict() also expects a 2D array
    return svr_rbf.predict(x)[0], svr_lin.predict(x)[0], svr_poly.predict(x)[0]

get_data('C:/local/ACA.csv')
predict_prices(dates, prices, 29)
Thanks in advance
get_data creates two lists, dates and prices. What do np.array(dates) and np.array(prices) produce? Shape and dtype? And since your plot shows only one date, we need to see the range of values of that array.
I edited your question, trying to make the function definitions correct. Make sure I did that right.
What does the date column in the csv look like?
Here is what your dates parsing does:
In [25]: txt = '2016-02-20'
In [26]: txt.split('-')
Out[26]: ['2016', '02', '20']
In [27]: int(txt.split('-')[0])
Out[27]: 2016
So you are grabbing just the year. That would account for the vertical scatter plot at
In [29]: 0.010+2.01599e3
Out[29]: 2016.0
I think a better date conversion would be to a np.datetime64 dtype:
In [28]: np.array([txt], dtype='datetime64[D]')
Out[28]: array(['2016-02-20'], dtype='datetime64[D]')
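Building on that, a minimal sketch of how such an array could be turned into the numeric column that SVR.fit expects. The sample dates are invented, and counting days from the earliest date is my own choice here, not something taken from the tutorial.

```python
import numpy as np

# Hypothetical date strings in the YYYY-MM-DD format from the question.
txts = ['2016-02-20', '2016-02-21', '2016-03-01']

# Parse to day-resolution datetimes.
dates = np.array(txts, dtype='datetime64[D]')

# Days elapsed since the earliest date, as plain integers.
days = (dates - dates.min()).astype('int64')

# Column vector of shape (n_samples, 1), the layout fit() and predict() need.
X = days.reshape(-1, 1)

print(days)
print(X.shape)
```

Each row now has its own x value, so the scatter plot spreads out along the date axis instead of stacking on 2016.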
I have been working with that SVM code from a number of examples (Siraj, Chaitjo, Jaihad, and others) and found out that the date needs to be in DD-MM-YYYY format, so the value used is the day of the month, not the year (as dark.vapor has described).
And the data can only be for 30 days, as seen in this code segment:
"predict_prices(dates, prices, 29)"
Otherwise, using data files with multiple months (with repeating day numbers, e.g. 15 Jan and 15 Feb), I get multiple prices plotted on each day instead of a single price per day.
Edit2: I played with varying the dataset and found that there can be more than 29 data rows, as long as the date is just an integer sequence. I went up to 85 days (rows) and they all plotted. So I am a bit confused as to what the "29" does in the above prediction code.
It would be nice to be able to use larger data files with multiple months and select the date ranges I want to test, but for now that's above my coding skills.
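For the multi-month case, here is a rough sketch of one way to pick a date range, using only the standard library. The helper name select_range, its parameters, and the sample rows are all hypothetical; it assumes each row is a (date text, price) pair in DD-MM-YYYY format.

```python
from datetime import datetime

def select_range(rows, start, end, fmt='%d-%m-%Y'):
    """Return (day_offset, price) pairs for rows dated within [start, end]."""
    first = datetime.strptime(start, fmt).date()
    last = datetime.strptime(end, fmt).date()
    picked = []
    for date_text, price in rows:
        day = datetime.strptime(date_text, fmt).date()
        if first <= day <= last:
            # Offset in days from the start of the range, not day of month.
            picked.append(((day - first).days, float(price)))
    return picked

# Made-up rows spanning three months.
rows = [('15-01-2016', '10.0'), ('15-02-2016', '11.5'),
        ('16-02-2016', '11.8'), ('15-03-2016', '12.1')]
window = select_range(rows, '01-02-2016', '29-02-2016')
print(window)
```

Because the x value counts days from the start of the range rather than the day of the month, 15 Jan and 15 Feb no longer land on the same point.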
I'm just a novice coder, so I hope this is accurate; using the DD-MM-YYYY format works fine for me and gives a good clean plot.
Hope this helps, Robert
Edit: I just found a good article describing this code, which confirms the "day" parsing with the DD-MM-YYYY format:
https://github.com/mKausthub/stock-er
dates.append(int(row[0].split('-')[0])) "gets day of the month which is at index zero since dates are in the format [date]-[month]-[year]."