
How to run a simple linear regression on grouped values in Python?

ID ANGLE SPEED
A -10 25
B 55 45
B 16 56
A 30 63
C -15 52
C 2 78
B -5 65
D 65 50
D 35 88
D 26 75
A 12 53
D 45 91
C 32 86
C 18 23
B 56 64
B 49 20
A 11 65

Above is a data snippet that I am looking at. I would like to run a basic linear regression with ANGLE as the predictor and SPEED as the target variable. I am having trouble summarizing the predicted values by ID group; ideally I'd like something like this:

ID PREDICTED_SPEED
A 32
B 45
C 48
D 27

I have been using this:

def model(df):
  y = df[['SPEED']].values
  X = df[['ANGLE']].values
  
  return np.squeeze(LinearRegression().fit(X, y).predict(X))

df.groupby('ID').apply(model,'Y',['X'])

Not having any luck, so any help would be appreciated. Thanks.

Please read this before the solution:

Are you sure you are solving a regression problem?
Because you predict on the same data the model was fitted on, the predictions will look better than they really are: the model has already seen those rows. A regression model is only useful if it predicts continuous values well for new samples it has never seen.
To check this, split the dataset into a training set and a test set, fit on the former, and predict on the latter. You can then evaluate the model with the residual sum of squares (RSS) or another metric.
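
As a rough illustration of that split, here is a minimal sketch using the ANGLE and SPEED values from the question. The 30% test fraction and `random_state=0` are arbitrary assumptions for the example, not something the question specifies, and with only 17 rows the resulting RSS will vary a lot from one split to another.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# ANGLE and SPEED values copied from the question, in row order.
angle = np.array(
    [-10, 55, 16, 30, -15, 2, -5, 65, 35, 26, 12, 45, 32, 18, 56, 49, 11],
    dtype=float,
).reshape(-1, 1)  # scikit-learn expects a 2-D feature matrix
speed = np.array(
    [25, 45, 56, 63, 52, 78, 65, 50, 88, 75, 53, 91, 86, 23, 64, 20, 65],
    dtype=float,
)

# Hold out part of the rows; test_size and random_state are arbitrary choices.
X_train, X_test, y_train, y_test = train_test_split(
    angle, speed, test_size=0.3, random_state=0
)

# Fit only on the training rows, then score on rows the model has not seen.
fitted = LinearRegression().fit(X_train, y_train)
residuals = y_test - fitted.predict(X_test)
rss = float(np.sum(residuals ** 2))

print(f"held-out rows: {len(y_test)}, RSS: {rss:.1f}")
```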

Solution:

I'm not sure what you are trying to do; I assume you want the mean predicted speed for each group. I have modified the code slightly for simplicity.

  • The model function takes only one parameter, so you cannot pass extra arguments to it through apply. I suggest calling model directly and storing the predictions in a new column.
  • After the groupby, apply mean to obtain the average for each group.
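import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Placeholder: df must hold the question's ID, ANGLE and SPEED rows before model() is called
df = pd.DataFrame(columns=['ID', 'ANGLE', 'SPEED'])

def model(df):
    y = df[['SPEED']].values
    X = df[['ANGLE']].values
    # Fit one model on all rows and return one prediction per row
    return np.squeeze(LinearRegression().fit(X, y).predict(X))

df['PREDICTION'] = model(df)
df.groupby('ID').mean()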
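
The snippet above fits a single model on all rows and then averages its predictions by ID. If the intent was instead one regression per ID, a possible sketch follows. The helper name `mean_prediction` and its `target` and `features` parameters are hypothetical, introduced only for this example; it relies on `groupby(...).apply` forwarding extra positional arguments to the function, which is how the call in the question could be made to work.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Data copied from the question, in row order.
df = pd.DataFrame({
    "ID": list("ABBACCBDDDADCCBBA"),
    "ANGLE": [-10, 55, 16, 30, -15, 2, -5, 65, 35, 26, 12, 45, 32, 18, 56, 49, 11],
    "SPEED": [25, 45, 56, 63, 52, 78, 65, 50, 88, 75, 53, 91, 86, 23, 64, 20, 65],
})


def mean_prediction(group, target, features):
    """Fit a regression on one group and return its mean in-sample prediction.

    Hypothetical helper: `target` is a column name, `features` a list of column names.
    """
    X = group[features].to_numpy()
    y = group[target].to_numpy()
    fitted = LinearRegression().fit(X, y)
    return float(fitted.predict(X).mean())


# Extra positional arguments after the function are passed on to it by apply.
per_group = (
    df.groupby("ID")[["ANGLE", "SPEED"]]
      .apply(mean_prediction, "SPEED", ["ANGLE"])
      .rename("PREDICTED_SPEED")
)

print(per_group)
```

One caveat with this variant: an ordinary least squares fit with an intercept has residuals that sum to zero, so the mean in-sample prediction of each per-group model equals that group's mean SPEED. To get a number that actually depends on the fitted line, predict at a chosen ANGLE for each group rather than averaging over the training rows.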
