簡體   English   中英

使用 SVM 預測概率

[英]Predict probabilities using SVM

我寫了這段代碼,想獲得分類的概率。

from sklearn import svm
X = [[0, 0], [10, 10],[20,30],[30,30],[40, 30], [80,60], [80,50]]
y = [0, 1, 2, 3, 4, 5, 6]
clf = svm.SVC() 
clf.probability=True
clf.fit(X, y)
prob = clf.predict_proba([[10, 10]])
print prob

我得到了這個輸出:

[[0.15376986 0.07691205 0.15388546 0.15389275 0.15386348 0.15383004 0.15384636]]

這很奇怪,因為概率應該是

[0 1 0 0 0 0 0 0]

(注意必須預測類別的樣本與第二個樣本相同)同樣,該類別獲得的概率最低。

您應該禁用probability並使用decision_function代替,因為不能保證predict_probapredict返回相同的結果。 您可以在文檔中閱讀有關它的更多信息。

clf.predict([[10, 10]]) // returns 1 as expected 

prop = clf.decision_function([[10, 10]]) // returns [[ 4.91666667  6.5         3.91666667  2.91666667  1.91666667  0.91666667
      -0.08333333]]
prediction = np.argmax(prop) // returns 1 

編輯:正如@TimH 所指出的,概率可以由clf.decision_function(X) 下面的代碼是固定的。 注意到使用predict_proba(X)指定的低概率問題,我認為答案是根據官方文檔here...。此外,它會在非常小的數據集上產生毫無意義的結果。

答案是理解 SVM 的結果概率是多少。 簡而言之,您在 2D 平面中有 7 個類和 7 個點。 SVM 試圖做的是在每個類之間找到一個線性分隔符(一對一方法)。 每次只選擇 2 個班級。 你得到的是歸一化后分類器的投票 這篇文章或這里(scikit-learn 使用 libsvm)查看更多關於libsvm 的多類 SVM 的詳細解釋。

通過稍微修改您的代碼,我們看到確實選擇了正確的類:

from sklearn import svm
import matplotlib.pyplot as plt
import numpy as np


X = [[0, 0], [10, 10],[20,30],[30,30],[40, 30], [80,60], [80,50]]
y = [0, 1, 2, 3, 3, 4, 4]
clf = svm.SVC() 
clf.fit(X, y)

x_pred = [[10,10]]
p = np.array(clf.decision_function(x_pred)) # decision is a voting function
prob = np.exp(p)/np.sum(np.exp(p),axis=1, keepdims=True) # softmax after the voting
classes = clf.predict(x_pred)

_ = [print('Sample={}, Prediction={},\n Votes={} \nP={}, '.format(idx,c,v, s)) for idx, (v,s,c) in enumerate(zip(p,prob,classes))]

對應的輸出是

Sample=0, Prediction=0,
Votes=[ 6.5         4.91666667  3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333] 
P=[ 0.75531071  0.15505748  0.05704246  0.02098475  0.00771986  0.00283998  0.00104477], 
Sample=1, Prediction=1,
Votes=[ 4.91666667  6.5         3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333] 
P=[ 0.15505748  0.75531071  0.05704246  0.02098475  0.00771986  0.00283998  0.00104477], 
Sample=2, Prediction=2,
Votes=[ 1.91666667  2.91666667  6.5         4.91666667  3.91666667  0.91666667 -0.08333333] 
P=[ 0.00771986  0.02098475  0.75531071  0.15505748  0.05704246  0.00283998  0.00104477], 
Sample=3, Prediction=3,
Votes=[ 1.91666667  2.91666667  4.91666667  6.5         3.91666667  0.91666667 -0.08333333] 
P=[ 0.00771986  0.02098475  0.15505748  0.75531071  0.05704246  0.00283998  0.00104477], 
Sample=4, Prediction=4,
Votes=[ 1.91666667  2.91666667  3.91666667  4.91666667  6.5         0.91666667 -0.08333333] 
P=[ 0.00771986  0.02098475  0.05704246  0.15505748  0.75531071  0.00283998  0.00104477], 
Sample=5, Prediction=5,
Votes=[ 3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333  6.5  4.91666667] 
P=[ 0.05704246  0.02098475  0.00771986  0.00283998  0.00104477  0.75531071  0.15505748], 
Sample=6, Prediction=6,
Votes=[ 3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333  4.91666667  6.5       ] 
P=[ 0.05704246  0.02098475  0.00771986  0.00283998  0.00104477  0.15505748  0.75531071], 

您還可以看到決策區:

X = np.array(X)
y = np.array(y)
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)

XX, YY = np.mgrid[0:100:200j, 0:100:200j]
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()])

Z = Z.reshape(XX.shape)
plt.figure(1, figsize=(4, 3))
plt.pcolormesh(XX, YY, Z, cmap=plt.cm.Paired)

for idx in range(7):
    ax.scatter(X[idx,0],X[idx,1], color='k')

在此處輸入圖片說明

您可以在文檔閱讀...

SVC 方法 decision_function 為每個樣本提供每個類別的分數(或在二元情況下每個樣本的單個分數)。 當構造函數選項概率設置為 True 時,啟用類成員概率估計(來自方法 predict_proba 和 predict_log_proba)。 在二元情況下,概率使用 Platt scaling 進行校准:SVM 分數的邏輯回歸,通過對訓練數據進行額外的交叉驗證來擬合。 在多類情況下,這是根據 Wu 等人的擴展。 (2004)。

毋庸置疑,Platt 縮放中涉及的交叉驗證對於大型數據集來說是一項昂貴的操作 此外,概率估計可能與分數不一致,因為分數的“argmax”可能不是概率的 argmax。 (例如,在二元分類中,根據 predict_proba樣本可能被 predict 標記為屬於概率 < 1/2 的類別。)眾所周知,Platt 的方法也存在理論問題。 如果需要置信度分數,但這些分數不一定是概率,那么建議設置概率=False 並使用decision_function 代替predict_proba。

Stack Overflow 用戶中也有很多關於此功能的混淆,正如您在此線程線程中所見。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM