
Error in Python program: “Expected 2D array, got 1D array instead”

I am trying to predict the price and also plot the data to visualize it, but I get an error that I am not able to figure out.

dates=[]
prices=[]

def getdata(filename):
    with open(filename,'r') as csvfile:
        csvFilereader=csv.reader(csvfile)
        next(csvFilereader)
        for row in csvFilereader:

            dates.append(int(row[0].split('-')[0]))
            prices.append(float(row[1]))
    return
def predicted_price(dates, prices, x):

    dates=np.reshape(dates,len(dates),1)


    svr_linear= SVR(kernel='linear', C=1e3)
    svr_poly= SVR(kernel='poly', C=1e3, degree=2)
    svr_rbf= SVR(kernel='rbf', C=1e3, gamma=0.1)

    svr_linear.fit(dates,prices)
    svr_poly.fit(dates,prices)
    svr_rbf.fit(dates,prices)

    plt.scatter(dates,prices, color='black', label='Data')
    plt.plot(dates, svr.rbf.predict(dates), color='red', label='RBF Model')
    plt.plot(dates, svr.poly.predict(dates), color='blue', label='Poly Model')
    plt.plot(dates, svr.linear.predict(dates), color='green', label='Linera Model')

    plt.xlabel('Dates')
    plt.ylabel('Prices')
    plt.title('Regression')

    plt.legend()
    plt.show()

    return svr_rbf.predict(x)[0], svr_linerar.predict(x)[0], svr_poly(x)[0]


getdata('D:\\android\\trans1.csv')


predicted_prices=predicted_price(dates,prices,10)
print(predicted_prices)

Here is the error that I am getting:

Expected 2D array, got 1D array instead:
array=[19102018. 19102018. 19102018. ... 22102018. 20102018. 23102018.].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.
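For reference, here is a minimal NumPy sketch of the two reshapes the message suggests. The short dates list is illustrative, built from values that appear in the error output:

```python
import numpy as np

# A few date values like the ones shown in the error message (illustrative only).
dates = [19102018, 19102018, 22102018, 20102018, 23102018]

one_feature = np.reshape(dates, (-1, 1))  # 5 samples, 1 feature each
one_sample = np.reshape(dates, (1, -1))   # 1 sample with 5 features

print(np.shape(dates))    # (5,)
print(one_feature.shape)  # (5, 1)
print(one_sample.shape)   # (1, 5)
```

Since each date is one observation of a single feature, (-1, 1) is the shape SVR.fit expects for X here.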

Changing the call predicted_price(dates,prices,10) to predicted_price([dates,prices,10]) gives this error:

predicted_price() missing 2 required positional arguments: 'prices' and 'x'

Here is an image of the data:

[Image: date and price data]

This code has at least 3 issues:

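  • getdata does not return anything; it only works because dates and prices are globals. Move both lists into getdata and return dates, prices.
  • SVR is not imported (from sklearn.svm, I guess).
  • Do what the error message tells you: dates = np.reshape(dates, (-1, 1)). In np.reshape(dates,len(dates),1) the shape is only len(dates), so the result stays 1D; the trailing 1 is taken as the separate order argument, not as part of the shape. (dates is a plain list, so dates.reshape(-1, 1) only works after converting it with np.array.)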
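Putting the three points together, here is a sketch of a corrected script. It assumes SVR comes from sklearn.svm and that csv, numpy and matplotlib.pyplot were imported in the part of the script that is not shown. Beyond the three issues above, it fixes the name errors that would fail next (svr.rbf, svr.poly, svr.linear, svr_linerar, and svr_poly(x) without .predict) and passes x as [[x]], because predict also expects a 2D array. The show_plot flag and the __main__ guard are my additions for convenience, not part of the original code.

```python
import csv

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVR


def getdata(filename):
    """Read the CSV and return (dates, prices) instead of filling globals."""
    dates = []
    prices = []
    with open(filename, 'r', newline='') as csvfile:
        csv_reader = csv.reader(csvfile)
        next(csv_reader)  # skip the header row
        for row in csv_reader:
            dates.append(int(row[0].split('-')[0]))
            prices.append(float(row[1]))
    return dates, prices


def predicted_price(dates, prices, x, show_plot=True):
    # scikit-learn expects X with shape (n_samples, n_features):
    # one feature (the date) per sample, hence (-1, 1).
    dates = np.reshape(dates, (-1, 1))

    svr_linear = SVR(kernel='linear', C=1e3)
    svr_poly = SVR(kernel='poly', C=1e3, degree=2)
    svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)

    svr_linear.fit(dates, prices)
    svr_poly.fit(dates, prices)
    svr_rbf.fit(dates, prices)

    if show_plot:  # hypothetical flag so the fitting can run without a plot window
        plt.scatter(dates, prices, color='black', label='Data')
        plt.plot(dates, svr_rbf.predict(dates), color='red', label='RBF Model')
        plt.plot(dates, svr_poly.predict(dates), color='blue', label='Poly Model')
        plt.plot(dates, svr_linear.predict(dates), color='green', label='Linear Model')
        plt.xlabel('Dates')
        plt.ylabel('Prices')
        plt.title('Regression')
        plt.legend()
        plt.show()

    # predict() also needs 2D input, so wrap the single value x as [[x]].
    x = [[x]]
    return svr_rbf.predict(x)[0], svr_linear.predict(x)[0], svr_poly.predict(x)[0]


if __name__ == '__main__':
    dates, prices = getdata('D:\\android\\trans1.csv')
    print(predicted_price(dates, prices, 10))
```

The returned tuple keeps the question's order: RBF, linear, then polynomial prediction.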

Sub-question about parameters

Changing the call predicted_price(dates,prices,10) to predicted_price([dates,prices,10]) gives this error:

predicted_price() missing 2 required positional arguments: 'prices' and 'x'

When you write [dates,prices,10] you construct a single list, and that single list is what you pass to the function. It binds to the first parameter, dates, so prices and x receive nothing. The function expects three arguments, not one, so call it as predicted_price(dates,prices,10).
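A small sketch that reproduces this. The function is a hypothetical stand-in with the same signature as predicted_price in the question; it only echoes its arguments, and the data values are illustrative:

```python
def predicted_price(dates, prices, x):
    # Hypothetical stand-in with the question's signature; it just echoes its arguments.
    return dates, prices, x


dates = [19102018, 20102018]
prices = [1.5, 2.0]

# Three separate arguments: one per parameter.
print(predicted_price(dates, prices, 10))

# One list: it binds to `dates`, so `prices` and `x` are missing.
try:
    predicted_price([dates, prices, 10])
except TypeError as err:
    print(err)  # predicted_price() missing 2 required positional arguments: 'prices' and 'x'
```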

Another note: the parentheses (...) belong to the function call, not to the data. This is important, because

predicted_price(dates,prices,10)

is different from

predicted_price((dates,prices,10))

The first one is correct; the second one constructs a tuple and passes it to predicted_price as a single argument, which fails with the same "missing 2 required positional arguments" error.
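A sketch of the tuple case, again with a hypothetical stand-in function and illustrative values. If the three values really are bundled in a tuple, Python's * unpacking spreads them back into separate arguments:

```python
def predicted_price(dates, prices, x):
    # Hypothetical stand-in with the question's signature; reports what it received.
    return type(dates).__name__, x


bundle = ([19102018, 20102018], [1.5, 2.0], 10)

# Passing the tuple itself: one argument, so prices and x are missing.
try:
    predicted_price(bundle)  # same as predicted_price((dates, prices, 10))
except TypeError as err:
    print(err)

# Unpacking the tuple: three separate arguments.
print(predicted_price(*bundle))  # ('list', 10)
```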

If you can put together a minimal complete example including some data, you might want to ask for feedback on codereview.stackexchange.com.
