
Support Vector Machine Python 3.5.2

While searching for a tutorial on SVMs, I found the code below online (Support Vector Machine _ Illustration), but it yields a weird chart. After debugging the code, I wonder whether the cause lies in the dates list, specifically:

dates.append(int(row[0].split('-')[0]))

which always gives the same value on my side (i.e. 2016), or whether there is something else, although I don't see anything abnormal in the code.

EDIT

I deduced this from the following:

plt.scatter(dates, prices, color='black', label='Data')
plt.show()

which in fact yields a vertical line, whereas

dates.append(int(row[0].split('-')[0]))

is supposed, as described in the link and reflected in the code, to convert each YYYY-MM-DD date to a distinct integer value.
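For reference, a minimal sketch (assuming YYYY-MM-DD strings; the helper names year_only and day_number are hypothetical) contrasting the original parsing with one that gives each date its own integer, here the standard library's day ordinal. It uses datetime.strptime rather than date.fromisoformat so it also runs on Python 3.5.

```python
from datetime import datetime

def year_only(text):
    # What the original get_data line does: keep only the first '-' field
    return int(text.split('-')[0])

def day_number(text):
    # Parse the full YYYY-MM-DD date and map it to a day count (1 = 0001-01-01)
    return datetime.strptime(text, '%Y-%m-%d').date().toordinal()

rows = ['2016-02-18', '2016-02-19', '2016-02-22']
print([year_only(t) for t in rows])                          # [2016, 2016, 2016]
print([day_number(t) - day_number(rows[0]) for t in rows])   # [0, 1, 4]
```

Every row collapses to 2016 with the first function, which is exactly the single vertical column in the scatter plot; the second keeps the gaps between trading days.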

EDIT (2)

Substituting dates.append(md.datestr2num(row[0])) (with matplotlib.dates imported as md) for dates.append(int(row[0].split('-')[0])) in the function get_data(filename) does help!
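A minimal sketch of what md.datestr2num returns, assuming md is matplotlib.dates: each date string becomes a float day number, so consecutive calendar days differ by 1.0. The absolute values depend on the epoch of the installed matplotlib version (the default epoch changed in 3.3), but the differences between dates do not.

```python
import matplotlib.dates as md

# datestr2num parses one string, or a sequence of strings, into float day numbers
single = md.datestr2num('2016-02-20')
several = md.datestr2num(['2016-02-19', '2016-02-20', '2016-03-01'])

# Consecutive calendar days differ by exactly 1.0, so every date is distinct
print(several[1] - several[0])   # 1.0
print(several[2] - several[1])   # 10.0 (2016 is a leap year)
```

In get_data this corresponds to dates.append(md.datestr2num(row[0])), so the x values spread out over the calendar instead of piling up on one year.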


import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt

dates = []
prices = []

def get_data(filename):
    with open(filename, 'r') as csvfile:
        csvFileReader = csv.reader(csvfile)
        next(csvFileReader)
        for row in csvFileReader:
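            dates.append(int(row[0].split('-')[0]))  # keeps only the first '-' field, i.e. the year for YYYY-MM-DD
            prices.append(float(row[6]))  # price column (index 6); columns from index 1 run from opening to closing price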

    return

def predict_prices(dates, prices, x):
    # scikit-learn expects a 2-D feature array of shape (n_samples, 1)
    dates = np.reshape(dates, (len(dates), 1))
    svr_lin = SVR(kernel='linear', C=1e3)
    svr_poly = SVR(kernel='poly', C=1e3, degree=2)
    svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)

    svr_lin.fit(dates, prices)
    svr_poly.fit(dates, prices)
    svr_rbf.fit(dates, prices)

    plt.scatter(dates, prices, color='black', label='Data')
    plt.plot(dates, svr_rbf.predict(dates), color='red', label='RBF model')
    # each line must use its own model (the original plotted svr_rbf three times)
    plt.plot(dates, svr_lin.predict(dates), color='green', label='Linear model')
    plt.plot(dates, svr_poly.predict(dates), color='blue', label='Polynomial model')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.title('Support Vector Regression')
    plt.legend()  # must be called; a bare plt.legend does nothing
    plt.show()
    # predict() also needs a 2-D input, so wrap the scalar x as [[x]]
    x = np.reshape(x, (1, 1))
    return svr_rbf.predict(x)[0], svr_lin.predict(x)[0], svr_poly.predict(x)[0]

get_data('C:/local/ACA.csv')
predict_prices(dates, prices, 29)

Thanks in advance

get_data creates two lists, dates and prices.

What do np.array(dates) and np.array(prices) produce? What are their shapes and dtypes? And since your plot shows only one date, we need to see the range of values in that array.
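A minimal sketch of those checks, using hypothetical stand-in values shaped like what get_data produces from YYYY-MM-DD rows (not the asker's CSV):

```python
import numpy as np

# Hypothetical sample data: the year-only parsing gives the same x for every row
dates = [2016, 2016, 2016, 2016]
prices = [10.2, 10.5, 10.1, 10.8]

d = np.array(dates)
p = np.array(prices)
print(d.shape, d.dtype)   # (4,) and an integer dtype (int64 or int32 by platform)
print(p.shape, p.dtype)   # (4,) float64
print(d.min(), d.max())   # 2016 2016 -> every point shares one x value
```

A min equal to the max is the quickest sign that the date column has collapsed to a single value.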

I edited your question to try to make the function definitions correct. Please check that I did that right.

What does the date column in the csv look like?

It looks like your date parsing does this:

In [25]: txt = '2016-02-20'

In [26]: txt.split('-')
Out[26]: ['2016', '02', '20']

In [27]: int(txt.split('-')[0])
Out[27]: 2016

So you are grabbing just the year. That would account for the vertical scatter plot at

In [29]: 0.010+2.01599e3
Out[29]: 2016.0

I think a better date conversion would be to a np.datetime64 dtype:

In [28]: np.array([txt], dtype='datetime64[D]')
Out[28]: array(['2016-02-20'], dtype='datetime64[D]')
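SVR needs numeric features, so a datetime64 array still has to be turned into numbers before fitting. A minimal sketch, assuming hypothetical date strings from the CSV's first column, that converts them to integer day offsets and reshapes them into the (n_samples, 1) form that fit expects:

```python
import numpy as np

# Hypothetical date strings as read from the CSV's first column
txt = ['2016-02-18', '2016-02-19', '2016-02-22', '2016-03-01']
d = np.array(txt, dtype='datetime64[D]')

# Subtracting the first date gives timedelta64 day counts; cast them to int
days = (d - d[0]).astype(int)
print(days)        # [ 0  1  4 12]

# SVR.fit expects a 2-D feature array of shape (n_samples, 1)
X = days.reshape(-1, 1)
print(X.shape)     # (4, 1)
```

Offsets from the first date keep the numbers small, which tends to suit kernels such as RBF with gamma=0.1 better than large absolute day counts.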

I have been working with that SVM code from a number of examples (Siraj, Chaitjo, Jaihad, and others) and found that the date needs to be in DD-MM-YYYY format, so that the value used is the day of the month, not the year (as dark.vapor has described).

And the data can only cover 30 days, as seen in this code segment:

"predict_prices(dates, prices, 29)"

Otherwise, when I use data files spanning multiple months (with repeating day numbers, e.g. 15 Jan and 15 Feb), I get multiple prices plotted on each day instead of a single price per day.
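The collision comes from using the day of the month as the x value. A minimal sketch, assuming hypothetical DD-MM-YYYY rows that span two months, comparing that with days elapsed since the first row:

```python
from datetime import datetime

# Hypothetical DD-MM-YYYY rows spanning two months
rows = ['14-01-2016', '15-01-2016', '15-02-2016', '16-02-2016']

# Day-of-month parsing (index 0 of the split) repeats across months
day_of_month = [int(r.split('-')[0]) for r in rows]
print(day_of_month)   # [14, 15, 15, 16] -> 15 Jan and 15 Feb collide

# Days elapsed since the first row stay unique across months
parsed = [datetime.strptime(r, '%d-%m-%Y').date() for r in rows]
elapsed = [(p - parsed[0]).days for p in parsed]
print(elapsed)        # [0, 1, 32, 33]
```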

Edit 2: I experimented with varying the dataset and found that there can be more than 29 data rows, as long as the date is just an integer sequence. I went up to 85 days (rows) and they all plotted. So I am a bit confused: what does the "29" do in the prediction code above?
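Judging from the question's code, x is not a row limit: the three models are fitted on all rows, and x is only the date value at which predict_prices asks each fitted model for a prediction. A minimal sketch, assuming integer day numbers 1 to 85 and made-up prices (hypothetical data), of predicting at day 29 with SVR; scikit-learn expects a 2-D input for predict, hence [[x]].

```python
import numpy as np
from sklearn.svm import SVR

# Hypothetical data: 85 consecutive day numbers with a noisy upward trend
rng = np.random.default_rng(0)
days = np.arange(1, 86).reshape(-1, 1)               # shape (85, 1)
prices = 50 + 0.2 * days.ravel() + rng.normal(0, 0.5, 85)

model = SVR(kernel='linear', C=1e3)
model.fit(days, prices)                              # fitted on all 85 rows

x = 29
pred = model.predict([[x]])                          # one sample, one feature
print(pred.shape)                                    # (1,)
print(pred[0])                                       # predicted price at day 29
```

Changing x to 86 would ask for a forecast one day past the data instead.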

It would be nice to be able to use larger data files spanning multiple months and select the date ranges I want to test, but for now that's beyond my coding skills.
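Selecting a date range is mostly a boolean mask once the dates are real dates. A minimal sketch, assuming hypothetical multi-month data already loaded as YYYY-MM-DD strings and float prices, that keeps one window and builds SVR-ready day offsets:

```python
import numpy as np

# Hypothetical multi-month data (dates as YYYY-MM-DD, prices as floats)
dates = np.array(['2016-01-28', '2016-02-01', '2016-02-15',
                  '2016-03-01', '2016-03-10'], dtype='datetime64[D]')
prices = np.array([10.0, 10.4, 10.9, 11.3, 11.1])

# Keep only rows inside the chosen window (both ends inclusive)
start, end = np.datetime64('2016-02-01'), np.datetime64('2016-03-01')
mask = (dates >= start) & (dates <= end)
sel_dates, sel_prices = dates[mask], prices[mask]

# Day offsets from the window start, shaped (n, 1) for SVR.fit
X = (sel_dates - start).astype(int).reshape(-1, 1)
print(X.ravel())     # [ 0 14 29]
print(sel_prices)    # [10.4 10.9 11.3]
```

X and sel_prices can then be passed to predict_prices in place of the full lists.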

I'm just a novice coder, so I hope this is accurate; it seems to work for me with the DD-MM-YYYY format, which gives me a good clean plot.

Hope this helps, Robert

Edit: I just found a good article describing this code, which confirms the "day" parsing with the DD-MM-YYYY format:

https://github.com/mKausthub/stock-er

dates.append(int(row[0].split('-')[0])) "gets day of the month which is at index zero since dates are in the format [date]-[month]-[year]."

The technical posts on this site are published under the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM