sklearn support vector machine is not learning
I am trying to classify images using sklearn's svm.SVC classifier, but it is not learning: after training I get 0.1 accuracy. There are 10 classes, so 0.1 accuracy is the same as a random guess.
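To make the chance-level figure concrete, here is a minimal sketch that scores a constant predictor on 10 balanced classes. The labels and features are synthetic placeholders, not CIFAR-10, and `DummyClassifier` is used only as a baseline for comparison.

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Synthetic, perfectly balanced labels: 100 samples for each of 10 classes.
# These are placeholders, not the CIFAR-10 labels.
y = np.repeat(np.arange(10), 100)
X = np.zeros((y.size, 1))  # features are ignored by the dummy baseline

# A baseline that always predicts the single most frequent class.
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X, y)
baseline_accuracy = baseline.score(X, y)
print(baseline_accuracy)
```

A trained classifier that scores about the same as this baseline is probably predicting without using the pixel values at all.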
I am using the CIFAR-10 dataset: 10000 images, each represented as 3072 uint8 values. The first 1024 values are the red pixels, the second 1024 are the green pixels and the third 1024 are the blue pixels.
Each image also has a label, which is a number from 0 to 9.
Here is my code:
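For reference, the layout described above can be unpacked as in the sketch below. It assumes the standard CIFAR-10 image size of 32x32 pixels (1024 values per channel) and row-major order within each channel. The helper name `row_to_image` is hypothetical and not part of any library.

```python
import numpy as np

def row_to_image(row):
    """Turn one flat 3072-value row into a (32, 32, 3) height x width x RGB array.

    Assumes 32x32 images stored channel by channel (red, green, blue).
    """
    row = np.asarray(row, dtype=np.uint8)
    channels = row.reshape(3, 32, 32)    # (channel, height, width)
    return channels.transpose(1, 2, 0)   # (height, width, channel)

# Placeholder row standing in for one entry of train_data['data'].
example_row = (np.arange(3072) % 256).astype(np.uint8)
image = row_to_image(example_row)
print(image.shape)
```

This is only needed for viewing the images; svm.SVC takes the flat 3072-value rows directly.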
import numpy as np
from sklearn import preprocessing, svm
import pandas as pd
import pickle
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
train_data = pickle.load(open('data_batch_1','rb'), encoding='latin1')
test_data = pickle.load(open('test_batch','rb'), encoding='latin1')
X_train = np.array(train_data['data'])
y_train = np.array(train_data['labels'])
X_test = np.array(test_data['data'])
y_test = np.array(test_data['labels'])
clf = svm.SVC(verbose=True)
clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test)
joblib.dump(clf, 'Cifar-10-clf.pickle')
print(accuracy)
Does anyone know what my problem could be, or can you point me to resources to solve this?
I'm not sure, but I think you need to tune the parameters of SVC.
I tried some different parameters and got an accuracy of 0.318.
Here is the code:
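# coding: utf-8
# Python 3 port of the original Python 2 snippet (cPickle -> pickle, print statement -> print()).
import pickle
import numpy as np
from sklearn import svm
# encoding='latin1' matches how the question loads the same batch files.
with open('data/data_batch_1', 'rb') as f:
    train_data = pickle.load(f, encoding='latin1')
with open('data/test_batch', 'rb') as f:
    test_data = pickle.load(f, encoding='latin1')
X_train = np.array(train_data['data'])
y_train = np.array(train_data['labels'])
# Score on the first 1000 test images only.
X_test = np.array(test_data['data'][:1000])
y_test = np.array(test_data['labels'][:1000])
# gamma is ignored by the linear kernel; it only affects 'rbf', 'poly' and 'sigmoid'.
clf = svm.SVC(kernel='linear', C=10, gamma=0.01)
clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test)
print("Accuracy: ", accuracy)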
I also recommend the grid search function for automatically tuning the hyper-parameters.
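A minimal sketch of such a grid search with scikit-learn's `GridSearchCV` could look like the following. The data is a small synthetic stand-in so the snippet runs quickly, and the grid values are illustrative assumptions rather than recommended settings for CIFAR-10.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Small synthetic stand-in for the CIFAR-10 arrays (assumption:
# swap in X_train / y_train from the code above for a real run).
X, y = make_classification(
    n_samples=300, n_features=20, n_informative=10,
    n_classes=3, random_state=0,
)

# Illustrative grid: gamma is only searched for the RBF kernel,
# because the linear kernel does not use it.
param_grid = [
    {"kernel": ["linear"], "C": [0.1, 1, 10]},
    {"kernel": ["rbf"], "C": [0.1, 1, 10], "gamma": [0.1, 0.01, 0.001]},
]

# 3-fold cross-validation over every combination in the grid.
search = GridSearchCV(SVC(), param_grid, cv=3)
search.fit(X, y)

print(search.best_params_)
print(search.best_score_)
```

On the full 10000x3072 training batch every combination is refit once per fold, so it may be worth starting with a subset of the images and a small grid.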
The public scikit-learn documentation on tuning the hyper-parameters covers this.