
How to parallelise the .predict() method of a scikit-learn SVM (SVC) classifier?

I recently ran into a situation where I have a scikit-learn SVC classifier instance that has already been trained with .fit(), and I need to call .predict() on a very large number of instances.

Is there a way to parallelise only this .predict() call using scikit-learn's built-in tools?

from sklearn import svm

data_train = [[0,2,3],[1,2,3],[4,2,3]]
targets_train = [0,1,0]

clf = svm.SVC(kernel='rbf', degree=3, C=10, gamma=0.3, probability=True)
clf.fit(data_train, targets_train)

# this can be very large (~ a million records)
to_be_predicted = [[1,3,4]]
clf.predict(to_be_predicted)

If anyone knows a solution, I would be very grateful if you could share it.

This may be buggy, but something like this should do the trick. Basically, split your data into blocks and run the model on each block separately inside a joblib.Parallel call.

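import numpy as np
from joblib import Parallel, delayed  # sklearn.externals.joblib was removed in scikit-learn 0.23

n_cores = 2
to_be_predicted = np.asarray(to_be_predicted)  # .shape and slicing need a NumPy array
n_samples = to_be_predicted.shape[0]
# integer (start, stop) row boundaries of each block
slices = [
    (n_samples * i // n_cores, n_samples * (i + 1) // n_cores)
    for i in range(n_cores)
]

# predict each non-empty block in its own worker, then join the 1-D label arrays
results = np.concatenate(Parallel(n_jobs=n_cores)(
    delayed(clf.predict)(to_be_predicted[start:stop])
    for start, stop in slices
    if stop > start
))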
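The same block-splitting idea can be wrapped in a small reusable helper. This is a minimal sketch, assuming joblib and NumPy are installed and a positive `n_jobs`; the name `parallel_predict` is hypothetical, and `np.array_split` is swapped in for the hand-computed slice boundaries because it already handles row counts that do not divide evenly.

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn import svm


def parallel_predict(model, X, n_jobs=2):
    """Hypothetical helper: predict X in contiguous blocks, one block per worker.

    Assumes n_jobs is a positive integer.
    """
    X = np.asarray(X)
    # Never create more blocks than rows, so no block is empty
    # (predict() rejects arrays with zero samples).
    n_blocks = max(1, min(n_jobs, X.shape[0]))
    blocks = np.array_split(X, n_blocks)
    # Each worker receives a copy of the fitted model and one block of rows.
    results = Parallel(n_jobs=n_jobs)(
        delayed(model.predict)(block) for block in blocks
    )
    # predict() returns a 1-D label array per block; join them in the original order.
    return np.concatenate(results)


clf = svm.SVC(kernel='rbf', C=10, gamma=0.3)
clf.fit([[0, 2, 3], [1, 2, 3], [4, 2, 3]], [0, 1, 0])

X = np.random.RandomState(0).rand(1000, 3) * 5
labels = parallel_predict(clf, X, n_jobs=2)

# The parallel result should match a single sequential predict() call.
print(np.array_equal(labels, clf.predict(X)))
```

Splitting into one block per worker, rather than one row per task, keeps the per-task overhead (pickling the model, dispatching the job) small relative to the prediction work when there are around a million rows.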

A working version of the example above:

import numpy as np
from joblib import Parallel, delayed
from sklearn import svm

data_train = [[0,2,3],[1,2,3],[4,2,3]]
targets_train = [0,1,0]

clf = svm.SVC(kernel='rbf', degree=3, C=10, gamma=0.3, probability=True)
clf.fit(data_train, targets_train)

to_be_predicted = np.array([[1,3,4], [1,3,4], [1,3,5]])
clf.predict(to_be_predicted)  # sequential prediction, for comparison

n_cores = 3

# split the rows into n_cores contiguous chunks (here: one row per chunk)
chunks = np.array_split(to_be_predicted, n_cores)

parallel = Parallel(n_jobs=n_cores)
results = parallel(delayed(clf.predict)(chunk) for chunk in chunks)

np.concatenate(results)
array([1, 1, 0])
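Because the question's classifier is built with `probability=True`, the same chunking should also work for `predict_proba()`. A minimal sketch, recreating a fitted classifier so the block runs on its own: `predict_proba()` returns a 2-D array (one column per class) for each chunk, so the chunks are stacked row-wise with `np.vstack` rather than concatenated as 1-D labels. The extra row `[2, 2, 2]` and `random_state=0` are illustrative choices, not part of the original example.

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn import svm

clf = svm.SVC(kernel='rbf', C=10, gamma=0.3, probability=True, random_state=0)
clf.fit([[0, 2, 3], [1, 2, 3], [4, 2, 3]], [0, 1, 0])

X = np.array([[1, 3, 4], [1, 3, 4], [1, 3, 5], [2, 2, 2]])
n_jobs = 2

# Two contiguous chunks of two rows each.
chunks = np.array_split(X, n_jobs)
parts = Parallel(n_jobs=n_jobs)(delayed(clf.predict_proba)(c) for c in chunks)

# Each part has shape (rows_in_chunk, n_classes); stack them back into one array.
proba = np.vstack(parts)
print(proba.shape)  # (4, 2)
```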
