
Predicting with chi squared kernel for multilabel using sklearn

I'm trying to get predictions from an SVM that uses a precomputed chi-squared kernel. However, I run into errors when calling svm.predict().

from sklearn import preprocessing
from sklearn.metrics.pairwise import chi2_kernel
from sklearn.svm import SVC

min_max_scaler = preprocessing.MinMaxScaler()
X_train_scaled = min_max_scaler.fit_transform(features_train)
X_test_scaled = min_max_scaler.transform(features_test)

K = chi2_kernel(X_train_scaled)
svm = SVC(kernel='precomputed', cache_size=1000).fit(K, labels_train)
y_pred_chi2 = svm.predict(X_test_scaled)

The error I am getting is the following:

ValueError: bad input shape (4627L, 20L)

I guessed the problem was the multi-label targets, so I trained the classifier on only one category:

svm = SVC(kernel='precomputed', cache_size=1000).fit(K, labels_train[:, 0])

However, when I run svm.predict(X_test_scaled), I get the error:

ValueError: X.shape[1] = 44604 should be equal to 4627, the number of samples at training time

Why does the number of test samples have to match the number of training samples?

Here are the shapes of the relevant matrices (the features have 44604 dimensions and there are 20 categories):

X_train_scaled.shape    : (4627L, 44604L)
X_test_scaled.shape     : (4637L, 44604L)
K.shape                 : (4627L, 4627L)
labels_train.shape      : (4627L, 20L)

On a side note, is it normal that there is an L after each dimension in these shapes?

You need to give the predict function the kernel between the test data and the training data. The easiest way to do that is to pass a callable as the kernel parameter: kernel=chi2_kernel. Using

K_test = chi2_kernel(X_test_scaled)

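will not work, because it computes the kernel between test samples only. It needs to be

K_test = chi2_kernel(X_test_scaled, X_train_scaled)

which has shape (4637, 4627): one row per test sample and one column per training sample. That is why the error message requires the second dimension to equal 4627, the number of training samples; the number of test rows can be anything.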
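The callable-kernel route mentioned above can be sketched as follows, assuming small random arrays as hypothetical stand-ins for features_train, features_test and a single binary label column. With a callable kernel, SVC keeps the training rows and evaluates the kernel between test and training rows itself at predict time. The np.clip call is an extra assumption not in the original: chi2_kernel rejects negative inputs, and MinMaxScaler.transform can produce negatives for test values below the training minimum.

```python
import numpy as np
from sklearn.metrics.pairwise import chi2_kernel
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC

# Hypothetical toy data standing in for the real features and one label column.
rng = np.random.default_rng(0)
features_train = rng.random((60, 10))
features_test = rng.random((15, 10))
labels_train = rng.integers(0, 2, size=60)

scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(features_train)
# Assumption: clip at zero so chi2_kernel never sees negative values.
X_test_scaled = np.clip(scaler.transform(features_test), 0, None)

# SVC calls chi2_kernel(X, X_train) internally, both in fit and in predict,
# so predict receives raw (scaled) features rather than a kernel matrix.
svm = SVC(kernel=chi2_kernel).fit(X_train_scaled, labels_train)
y_pred = svm.predict(X_test_scaled)
print(y_pred.shape)  # (15,)
```

The trade-off is that the kernel is recomputed on every fit and predict call, whereas kernel='precomputed' lets you compute the training kernel once and reuse it.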

The input to svm.predict() must also be passed through chi2_kernel, together with the training data:

K_test = chi2_kernel(X_test_scaled, X_train_scaled)
y_pred = svm.predict(K_test)
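The precomputed-kernel route can be sketched end to end as follows, extending the question's own one-column workaround (labels_train[:, 0]) to every label column. The data are hypothetical random stand-ins (50 training rows, 12 test rows, 8 features, 3 labels), and the per-column loop is just one simple way to handle multi-label targets with a single-label SVC. The training kernel is square, while the test kernel is always (n_test, n_train).

```python
import numpy as np
from sklearn.metrics.pairwise import chi2_kernel
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC

# Hypothetical toy data: 50 train rows, 12 test rows, 8 features, 3 labels.
rng = np.random.default_rng(1)
features_train = rng.random((50, 8))
features_test = rng.random((12, 8))
labels_train = rng.integers(0, 2, size=(50, 3))  # multi-label indicator matrix

scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(features_train)
# Assumption: clip at zero because chi2_kernel requires non-negative inputs.
X_test_scaled = np.clip(scaler.transform(features_test), 0, None)

K_train = chi2_kernel(X_train_scaled)                # (n_train, n_train)
K_test = chi2_kernel(X_test_scaled, X_train_scaled)  # (n_test, n_train)

# SVC expects a 1-D label array, so fit one classifier per label column,
# reusing the same precomputed kernels for every column.
y_pred = np.column_stack([
    SVC(kernel='precomputed').fit(K_train, labels_train[:, j]).predict(K_test)
    for j in range(labels_train.shape[1])
])
print(K_train.shape, K_test.shape, y_pred.shape)  # (50, 50) (12, 50) (12, 3)
```

Computing the two kernels once outside the loop avoids recomputing the chi-squared distances for each of the label columns.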
