Python: Function for Multiple Regression

I have the following dataframe:

import pandas as pd
from sklearn import linear_model

import statsmodels.api as sm

Stock_Market = {'Year': [2017,2017,2017,2017,2017,2017,2017,2017,2017,2017,2017,2017,2016,2016,2016,2016,2016,2016,2016,2016,2016,2016,2016,2016],
                'Month': [12, 11,10,9,8,7,6,5,4,3,2,1,12,11,10,9,8,7,6,5,4,3,2,1],
                'Interest_Rate': [2.75,2.5,2.5,2.5,2.5,2.5,2.5,2.25,2.25,2.25,2,2,2,1.75,1.75,1.75,1.75,1.75,1.75,1.75,1.75,1.75,1.75,1.75],
                'Unemployment_Rate': [5.3,5.3,5.3,5.3,5.4,5.6,5.5,5.5,5.5,5.6,5.7,5.9,6,5.9,5.8,6.1,6.2,6.1,6.1,6.1,5.9,6.2,6.2,6.1],
                'Stock_Index_Price': [1464,1394,1357,1293,1256,1254,1234,1195,1159,1167,1130,1075,1047,965,943,958,971,949,884,866,876,822,704,719]        
                }

df = pd.DataFrame(Stock_Market,columns=['Year','Month','Interest_Rate','Unemployment_Rate','Stock_Index_Price'])

Currently, I'm able to regress 'Stock_Index_Price' on 'Interest_Rate' and 'Unemployment_Rate' (a multiple regression) using the following function:

def perform_regression_multiple(y, x1, x2=""):
    test = df[[y, x1, x2]].reset_index(drop=True)
    
    X = test[[x1, x2]]
    Y = test[[y]]
    
    regr = linear_model.LinearRegression()
    regr.fit(X, Y)

    model = sm.OLS(Y, X).fit()
    predictions = model.predict(X) 

    print_model = model.summary()
    print(print_model)
    
#===========================================================================

perform_regression_multiple('Stock_Index_Price', 'Interest_Rate', 'Unemployment_Rate')

However, when I try to perform a simple linear regression with the function above (e.g. with 'Interest_Rate' as the only explanatory variable), I receive the following error message:

perform_regression_multiple('Stock_Index_Price', 'Interest_Rate')

KeyError: "[''] not in index"

Evidently, both x1 and x2 have to be specified, otherwise it fails. How can I modify the function so that I can pass any number of explanatory variables? The goal is to be able to extend the regression model with additional factors.

Many thanks for any suggestions.

Take a look at how you are defining your function:

def perform_regression_multiple(y, x1, x2=""):

And then how you are calling it:

perform_regression_multiple('Stock_Index_Price', 'Interest_Rate')

With that call, you are telling the function that y="Stock_Index_Price", x1="Interest_Rate" and x2="", which is the default value.

On the very first line of your function, you are selecting the x2 column:

test = df[[y, x1, x2]].reset_index(drop=True)

Since x2 defaults to "", pandas looks for a column named "", and because no such column exists, it raises the KeyError.
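You can reproduce this on any DataFrame: selecting a list of columns that contains a name the DataFrame does not have (here the empty string) raises a KeyError. A minimal sketch with a small toy frame:

```python
import pandas as pd

toy = pd.DataFrame({'Stock_Index_Price': [1464, 1394], 'Interest_Rate': [2.75, 2.5]})

x2 = ""  # the default value that ends up in the column list
message = None
try:
    toy[['Stock_Index_Price', 'Interest_Rate', x2]]
except KeyError as err:
    # pandas reports the missing label(s), e.g. "[''] not in index"
    message = str(err)
    print("KeyError:", message)
```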

If you want to be able to run the regression with either one or two explanatory variables, you can do this:

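def perform_regression_multiple(y, x1, x2=None):
    if x2:
        # Two explanatory variables
        test = df[[y, x1, x2]].reset_index(drop=True)
        X = test[[x1, x2]]
    else:
        # x2 was not given, so regress on x1 only
        test = df[[y, x1]].reset_index(drop=True)
        X = test[[x1]]
    Y = test[[y]]

    regr = linear_model.LinearRegression()
    regr.fit(X, Y)

    # sm.OLS does not add an intercept by itself, so add a constant column
    X_const = sm.add_constant(X)
    model = sm.OLS(Y, X_const).fit()
    predictions = model.predict(X_const)

    print_model = model.summary()
    print(print_model)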

You could also keep the empty string as the default: the if x2: check behaves the same way, because both "" and None are falsy in Python.
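For illustration, here is a small hypothetical helper (select_xy is not part of the original code) that keeps the empty-string default and simply drops falsy names before selecting, which shows that "" and None are treated alike:

```python
import pandas as pd

def select_xy(df, y, x1, x2=""):
    """Hypothetical helper: return (X, Y), ignoring x2 when it is "" or None."""
    # Both "" and None are falsy, so either default is filtered out here
    x_cols = [x for x in (x1, x2) if x]
    return df[x_cols], df[[y]]

toy = pd.DataFrame({'Stock_Index_Price': [1464, 1394],
                    'Interest_Rate': [2.75, 2.5],
                    'Unemployment_Rate': [5.3, 5.3]})

X_one, _ = select_xy(toy, 'Stock_Index_Price', 'Interest_Rate')
X_none, _ = select_xy(toy, 'Stock_Index_Price', 'Interest_Rate', None)
X_two, _ = select_xy(toy, 'Stock_Index_Price', 'Interest_Rate', 'Unemployment_Rate')
print(list(X_one.columns), list(X_none.columns), list(X_two.columns))
```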

Even better: since selecting columns from a pandas DataFrame with a list of names returns a DataFrame, you can pass all explanatory variables as a list in an x_variables argument (even if the list contains just one item):

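def perform_regression_multiple(y: str, x_variables: list):
    columns = [y] + x_variables
    test = df[columns].reset_index(drop=True)

    X = test[x_variables]
    Y = test[[y]]

    regr = linear_model.LinearRegression()
    regr.fit(X, Y)

    # sm.OLS does not add an intercept by itself, so add a constant column
    X_const = sm.add_constant(X)
    model = sm.OLS(Y, X_const).fit()
    predictions = model.predict(X_const)

    print_model = model.summary()
    print(print_model)

perform_regression_multiple('Stock_Index_Price', ['Interest_Rate'])
perform_regression_multiple('Stock_Index_Price', ['Interest_Rate', 'Unemployment_Rate'])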

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM