sklearn support vector machine is not learning

I am trying to classify images using sklearn's svm.SVC classifier, but it is not learning: after training I get an accuracy of 0.1. There are 10 classes, so 0.1 is no better than random guessing.
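As a sanity check on that chance level, here is a minimal sketch. It uses synthetic balanced labels rather than the CIFAR-10 data, and scikit-learn's DummyClassifier as an assumed baseline. It shows that always predicting a single class scores 0.1 when there are 10 equally sized classes:

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Synthetic, balanced labels: 10 classes with 100 samples each (an assumption,
# standing in for the real label distribution).
y = np.repeat(np.arange(10), 100)
X = np.zeros((y.size, 1))  # features are ignored by the dummy baseline

# Always predict the most frequent class; with balanced classes this is chance level.
baseline = DummyClassifier(strategy="most_frequent").fit(X, y)
chance = baseline.score(X, y)
print(chance)  # 0.1
```

A trained classifier that scores the same as this baseline is most likely predicting one class for every input.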

I am using the CIFAR-10 dataset: 10000 images, each represented as 3072 uint8 values. The first 1024 values are the red pixels, the next 1024 are the green pixels, and the last 1024 are the blue pixels.

Each image also has a label, which is a number from 0 to 9.
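For reference, the row layout described above can be unpacked into ordinary images roughly like this. This is a sketch that assumes each 1024-value channel plane is a 32x32 image stored row by row; the helper name rows_to_images is hypothetical, and the input is a synthetic row, not a real CIFAR-10 batch:

```python
import numpy as np

def rows_to_images(flat_rows):
    """Convert (n, 3072) channel-planar rows into (n, 32, 32, 3) images.

    Hypothetical helper; assumes 32x32 planes in red, green, blue order.
    """
    flat_rows = np.asarray(flat_rows, dtype=np.uint8)
    # Split each row into 3 planes of 32x32, then move the channel axis last.
    return flat_rows.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# Synthetic stand-in for one row: red=10, green=20, blue=30 at every pixel.
row = np.concatenate([np.full(1024, 10), np.full(1024, 20), np.full(1024, 30)])
images = rows_to_images(row[np.newaxis, :])
print(images.shape)     # (1, 32, 32, 3)
print(images[0, 0, 0])  # [10 20 30]
```

The classifier itself still takes the flat 3072-value rows; the reshaped form is mainly useful for viewing images or computing per-channel statistics.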

Here is my code:

import pickle

import joblib  # sklearn.externals.joblib was removed from recent scikit-learn releases
import numpy as np
from sklearn import svm

# Load one training batch and the test batch (pickled under Python 2, hence latin1).
with open('data_batch_1', 'rb') as f:
    train_data = pickle.load(f, encoding='latin1')
with open('test_batch', 'rb') as f:
    test_data = pickle.load(f, encoding='latin1')

# Each row of 'data' is one image as 3072 raw uint8 values (no scaling applied).
X_train = np.array(train_data['data'])
y_train = np.array(train_data['labels'])
X_test = np.array(test_data['data'])
y_test = np.array(test_data['labels'])

# SVC with its default parameters.
clf = svm.SVC(verbose=True)
clf.fit(X_train, y_train)

accuracy = clf.score(X_test, y_test)

# Persist the fitted model to disk.
joblib.dump(clf, 'Cifar-10-clf.pickle')

print(accuracy)

Does anyone know what my problem could be, or can you point me to resources that would help me solve it?

I'm not sure, but I think you need to tune the parameters of SVC.

I tried a few parameter settings and got an accuracy of 0.318.

Here is the code:

import pickle

import numpy as np
from sklearn import svm

# The CIFAR-10 batches were pickled under Python 2, so Python 3 needs an explicit encoding.
with open('data/data_batch_1', 'rb') as f:
    train_data = pickle.load(f, encoding='latin1')
with open('data/test_batch', 'rb') as f:
    test_data = pickle.load(f, encoding='latin1')

X_train = np.array(train_data['data'])
y_train = np.array(train_data['labels'])
# Evaluate on the first 1000 test images only.
X_test = np.array(test_data['data'][:1000])
y_test = np.array(test_data['labels'][:1000])

# gamma has no effect with kernel='linear'; it is kept as in the original run.
clf = svm.SVC(kernel='linear', C=10, gamma=0.01)
clf.fit(X_train, y_train)

accuracy = clf.score(X_test, y_test)

print("Accuracy:", accuracy)

I also recommend using grid search to tune the hyper-parameters automatically.

The scikit-learn documentation has a public guide on tuning the hyper-parameters.
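A grid search along those lines could be sketched as follows. The data here is a small synthetic stand-in for the CIFAR-10 rows, and the parameter grid is an illustrative assumption, not a tuned recommendation. The StandardScaler step is also an addition that is not part of the code above; it is included because SVMs are generally sensitive to feature scale, and raw pixel values range from 0 to 255.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic stand-in for image rows: 3 classes, 40 samples each, 48 pixel-like
# features in the 0-255 range (an assumption; swap in X_train / y_train).
rng = np.random.RandomState(0)
n_classes, per_class, n_features = 3, 40, 48
centers = rng.randint(60, 200, size=(n_classes, n_features))
X = np.vstack([
    np.clip(c + rng.normal(0, 25, size=(per_class, n_features)), 0, 255)
    for c in centers
]).astype(np.float64)
y = np.repeat(np.arange(n_classes), per_class)

# Scale the features, then fit the SVC; both steps are refit inside each CV fold.
pipeline = make_pipeline(StandardScaler(), SVC())

# Illustrative grid only; make_pipeline names the SVC step "svc".
param_grid = {
    "svc__kernel": ["linear", "rbf"],
    "svc__C": [0.1, 1, 10],
    "svc__gamma": ["scale", 0.01],
}

search = GridSearchCV(pipeline, param_grid, cv=3)
search.fit(X, y)

print(search.best_params_)
print(search.best_score_)
```

On the real 10000 x 3072 training batch each fit is slow, so it may help to start with a subset of the rows and a small grid before widening the search.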
