
Scikit learn wrong predictions with SVC

I am trying to classify the MNIST dataset (http://pjreddie.com/projects/mnist-in-csv/) with an SVM using the radial (RBF) kernel. I want to train on a few examples (e.g. 1000) and predict many more. The problem is that the predictions are constant unless the indices of the test rows coincide with those of the training rows. That is, suppose I train on examples 1:1000 of my training set. Then the predictions are correct (i.e. the SVM does its best) for rows 1:1000 of my test set, but I get the same output for all the rest. If I instead train on examples 2001:3000, then only the test examples in those rows are labeled correctly (i.e. not with the same constant). I am completely at a loss, and I suspect some sort of bug, because the exact same code works just fine with LinearSVC, although evidently the accuracy of that method is lower.

First, I train with examples 501:1000 of training data:

import numpy as np
import pandas as pd
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix

# dat_train/dat_test are pandas DataFrames for the two MNIST CSV files
# (note: as posted, both calls read mnist_train.csv)
dat_train = pd.read_csv('data/mnist_train.csv', header=None)
dat_test = pd.read_csv('data/mnist_train.csv', header=None)

svm = SVC(C=10.0)
idx = range(1000)
#idx = np.random.choice(range(len(dat_train)), size=1000, replace=False)
# Column 0 holds the label, the remaining columns hold the pixel values
X_train = dat_train.iloc[idx, 1:].values
y_train = dat_train.iloc[idx, 0].values
X_test = dat_test.iloc[:, 1:].values
y_test = dat_test.iloc[:, 0].values
# Fit on rows 501..999 of the 1000-row subset
svm.fit(X=X_train[501:1000, :], y=y_train[501:1000])

Here you can see that about half the predictions are wrong

y_pred = svm.predict(X_test[:1000,:])
confusion_matrix(y_test[:1000], y_pred)

All wrong (i.e. constant)

y_pred = svm.predict(X_test[:500,:])
confusion_matrix(y_test[:500], y_pred)

This is what I would expect to see for all test data

y_pred = svm.predict(X_test[501:1000,:])
confusion_matrix(y_test[501:1000], y_pred)

You can check that all of the above are correct using LinearSVC!

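The default kernel is RBF, in which case gamma matters. In the scikit-learn version used here, gamma defaults to 'auto' when it is not provided, which means 1/n_features (releases from 0.22 on default to 'scale' instead). With raw pixel values in 0..255, 1/784 is far too large: the kernel value between any two different images is practically zero, so the model effectively only recognizes its own training rows and returns the same label for everything else. You should run a grid search to find good parameters. Here I just illustrate that the result is normal given suitable parameters.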
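To see why gamma matters so much on unscaled pixels, here is a minimal sketch that evaluates the RBF kernel, exp(-gamma * ||x - z||^2), by hand. The two vectors are random stand-ins for MNIST rows (an assumption, since the CSV files are not available here), and rbf_kernel_value is a helper made up for this illustration, not a scikit-learn function.

```python
import numpy as np

def rbf_kernel_value(x, z, gamma):
    """RBF kernel k(x, z) = exp(-gamma * ||x - z||^2)."""
    diff = np.asarray(x, dtype=float) - np.asarray(z, dtype=float)
    return float(np.exp(-gamma * np.dot(diff, diff)))

# Synthetic stand-ins for two MNIST rows: 784 raw pixel intensities in 0..255.
rng = np.random.RandomState(0)
a = rng.randint(0, 256, size=784)
b = rng.randint(0, 256, size=784)

k_auto = rbf_kernel_value(a, b, gamma=1.0 / 784)  # 'auto': 1 / n_features
k_small = rbf_kernel_value(a, b, gamma=1e-7)      # the value used in this answer
k_self = rbf_kernel_value(a, a, gamma=1.0 / 784)  # a point compared with itself

print(k_auto, k_small, k_self)
```

With gamma = 1/784 the kernel between two different vectors underflows to zero while a point compared with itself gives 1, so every training row looks unrelated to every other input. With gamma = 1e-7 distinct vectors keep a usable similarity.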

In [120]: svm = SVC(C=1, gamma=0.0000001)

In [121]: svm.fit(X=X_train[501:1000,:], y=y_train[501:1000])
Out[121]:
SVC(C=1, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=1e-07, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

In [122]: y_pred = svm.predict(X_test[:1000,:])

In [123]: confusion_matrix(y_test[:1000], y_pred)
Out[123]:
array([[ 71,   0,   2,   0,   2,   9,   1,   0,   0,   0],
       [  0, 123,   0,   0,   0,   1,   1,   0,   1,   0],
       [  2,   5,  91,   1,   1,   1,   3,   7,   5,   0],
       [  0,   1,   4,  48,   0,  40,   1,   5,   7,   1],
       [  0,   0,   0,   0,  88,   2,   3,   2,   0,  15],
       [  1,   1,   1,   0,   2,  77,   0,   3,   1,   1],
       [  3,   0,   3,   0,   5,   4,  72,   0,   0,   0],
       [  0,   2,   3,   0,   3,   0,   1,  88,   1,   1],
       [  2,   0,   1,   2,   3,   9,   1,   4,  63,   4],
       [  0,   1,   0,   0,  16,   3,   0,  11,   1,  62]])
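The grid search suggested above could look roughly like the sketch below. It assumes a recent scikit-learn (where GridSearchCV lives in sklearn.model_selection) and uses the bundled 8x8 digits dataset as a small stand-in for MNIST. The parameter ranges are illustrative guesses for that dataset, not tuned values for the MNIST CSV files.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Stand-in data: 8x8 digit images with pixel values 0..16 (not the MNIST CSV).
X, y = load_digits(return_X_y=True)
X_small, y_small = X[:500], y[:500]  # train on few examples, as in the question

# Illustrative search ranges; on raw 0..255 MNIST pixels gamma would need
# to be much smaller (e.g. around the 1e-7 used above).
param_grid = {
    "C": [0.1, 1, 10, 100],
    "gamma": [1e-4, 1e-3, 1e-2, 1e-1],
}

search = GridSearchCV(SVC(kernel="rbf"), param_grid, cv=3)
search.fit(X_small, y_small)

print(search.best_params_, search.best_score_)
```

After fitting, search.best_estimator_ is an SVC refit on all of X_small with the best parameters, so it can be used directly for predict.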

Finding good parameters for an SVC is an art in itself. Grid search might help, but population-based training, as described in this article, works better; I recently tried it. Given the same running time, it gives better results than grid search, and if you run it until the accuracy is the same, it gets there faster.

It also helps to plot the results: put C and gamma on the x and y axes and show the prediction scores as color. Usually you will find a rough V shape, with the best training results at the point where the two lines meet. That point also tends to have a low C value, which is desirable because C affects the runtime of the SVC: a high C leads to a long runtime.
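A plot like this can be sketched as follows. The example again uses scikit-learn's bundled digits dataset as a stand-in for MNIST, with made-up parameter ranges. plot_score_grid is a hypothetical helper that needs matplotlib, which is imported only when the function is called.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

# Stand-in data (not the MNIST CSV): 8x8 digits, first 300 rows.
X, y = load_digits(return_X_y=True)
X_small, y_small = X[:300], y[:300]

C_values = [0.1, 1, 10, 100]          # illustrative ranges
gamma_values = [1e-5, 1e-4, 1e-3, 1e-2]

# scores[i, j] = mean 3-fold CV accuracy for gamma_values[i] and C_values[j]
scores = np.array([
    [cross_val_score(SVC(C=C, gamma=g), X_small, y_small, cv=3).mean()
     for C in C_values]
    for g in gamma_values
])

def plot_score_grid(scores, C_values, gamma_values):
    """Draw the score grid as a heatmap: C on x, gamma on y, score as color."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting
    fig, ax = plt.subplots()
    im = ax.imshow(scores, origin="lower", cmap="viridis")
    ax.set_xticks(range(len(C_values)))
    ax.set_xticklabels(C_values)
    ax.set_yticks(range(len(gamma_values)))
    ax.set_yticklabels(gamma_values)
    ax.set_xlabel("C")
    ax.set_ylabel("gamma")
    fig.colorbar(im, ax=ax, label="mean CV accuracy")
    return fig

print(scores.round(3))
```

Calling plot_score_grid(scores, C_values, gamma_values) and then saving or showing the returned figure gives the color map described above. Both axes are laid out on evenly spaced ticks because the values are spaced logarithmically.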
