Decision Tree Of SkLearn: Overfitting or Bug?

I'm analyzing the training error and validation error of my decision tree model using the tree package of sklearn.

import numpy as np

# compute the misclassification rate (fraction of wrong predictions)
def compute_error(x, y, model):
    yfit = model.predict(x.toarray())
    return np.mean(y != yfit)

def drawLearningCurve(model, xTrain, yTrain, xTest, yTest):
    # training-set sizes to evaluate, from 2 up to 25000 samples
    sizes = np.linspace(2, 25000, 50).astype(int)
    train_error = np.zeros(sizes.shape)
    crossval_error = np.zeros(sizes.shape)

    for i, size in enumerate(sizes):
        # fit on the first `size` training samples
        model = model.fit(xTrain[:size, :].toarray(), yTrain[:size])

        # compute the validation error
        crossval_error[i] = compute_error(xTest, yTest, model)

        # compute the training error
        train_error[i] = compute_error(xTrain[:size, :], yTrain[:size], model)

from sklearn import tree
clf = tree.DecisionTreeClassifier()
drawLearningCurve(clf, xtr, ytr, xte, yte)

The problem (if it is a problem at all) is that when I pass a decision tree as the model to the function drawLearningCurve , the training error comes out as 0.0 on every iteration of the loop. Is this related to the nature of my dataset, or to the tree package of sklearn? Or is something else wrong?

PS: The training error is definitely not 0.0 for other models such as naive Bayes, kNN or ANN.

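The comments give some pretty useful directions. I'd just like to add that the parameter you might want to tweak is called max_depth . By default a DecisionTreeClassifier keeps splitting until its leaves are pure, so it can memorize the training set and report a training error of 0.0.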
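As a rough illustration of what max_depth does, here is a minimal sketch. It assumes synthetic data from make_classification as a stand-in for the asker's dataset, and the helper error_rate and all variable names are made up for this example. It fits one unconstrained tree and one tree limited to depth 3, then compares their training errors.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption; not the dataset from the question).
# flip_y adds label noise so that a perfect fit means memorizing noise.
X, y = make_classification(n_samples=2000, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)


def error_rate(model, X, y):
    """Fraction of misclassified samples (hypothetical helper)."""
    return 1.0 - model.score(X, y)


# Default tree: no depth limit, grows until every leaf is pure.
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Constrained tree: at most 3 levels of splits.
shallow_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

full_train_error = error_rate(full_tree, X_train, y_train)
shallow_train_error = error_rate(shallow_tree, X_train, y_train)

print("full tree    train error:", full_train_error,
      " test error:", error_rate(full_tree, X_test, y_test))
print("shallow tree train error:", shallow_train_error,
      " test error:", error_rate(shallow_tree, X_test, y_test))
```

On data like this the unconstrained tree fits the training set perfectly, while the depth-limited tree cannot. Which depth gives the best validation error depends on the dataset, so it is worth trying several values.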

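What worries me more is that your compute_error function is fragile. An error of 0 says your classifier makes no mistakes on the training set. However, if y and yfit were ever plain Python lists rather than NumPy arrays, y != yfit would compare the two lists as a whole and return a single boolean, so the function would report 1.0 for any mismatch instead of the fraction of wrong answers. (model.predict returns a NumPy array, so the comparison is elementwise in the code as posted.) With lists: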

import numpy as np
np.mean([0,0,0,0] != [0,0,0,0]) # perfect match, error is 0
0.0

np.mean([0,0,0,0] != [1, 1, 1, 1]) # 100% wrong answers
1.0

np.mean([0,0,0,0] != [1, 1, 1, 0]) # 75% wrong answers
1.0

np.mean([0,0,0,0] != [1, 1, 0, 0]) # 50% wrong answers
1.0

np.mean([0,0,0,0] != [1, 1, 2, 2]) # 100% wrong answers
1.0

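What you want is an elementwise comparison on NumPy arrays, for example np.mean(np.asarray(y) != np.asarray(yfit)) , or even better, one of the metrics that come with sklearn, such as accuracy_score (the error rate is then 1 - accuracy_score ).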
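To make the difference concrete, the following sketch compares the three approaches on one small hand-made example. The variable names are illustrative only and do not come from the question.

```python
import numpy as np
from sklearn.metrics import accuracy_score

y_true = [0, 0, 0, 0]
y_pred = [1, 1, 1, 0]  # three of the four predictions are wrong

# Plain lists: != compares the lists as a whole and yields a single boolean,
# so the "mean" is just that boolean converted to a number.
list_result = np.mean(y_true != y_pred)

# NumPy arrays: != is elementwise, so the mean is the fraction of mismatches.
array_result = np.mean(np.asarray(y_true) != np.asarray(y_pred))

# sklearn metric: the error rate is one minus the accuracy.
sklearn_result = 1.0 - accuracy_score(y_true, y_pred)

print(list_result, array_result, sklearn_result)
```

The array-based version and the sklearn-based version agree with each other; only the list-based version loses the per-sample information.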
