簡體   English   中英

Plot SVM 決策邊界

[英]Plot SVM decision boundary

以下代碼使用多項式 kernel 對鳶尾花(iris)數據擬合 SVM,並繪制其決策邊界。 輸入 X 使用數據的前 2 列,即萼片長度和寬度。 但是,當我改用第 3 列和第 4 列(即花瓣的長度和寬度)作為 X 時,無法重現相同的輸出。 如何更改 plot function 以使代碼正常工作? 提前致謝。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC

# Load the iris dataset once; features and labels are taken from it below.
iris = datasets.load_iris()
# Alternative feature choice (sepal length and width): iris.data[:, :2]
# Columns 2 onward (petal length and width) are the features that trigger
# the plotting failure discussed here.
X, y = iris.data[:, 2:], iris.target

# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title):
    """Plot the classifier's predicted regions together with the data points.

    Reads the module-level ``X``, ``y`` and ``svm_mod``.

    Args:
        title: Text passed to ``plt.title``.
    """
    # create a mesh to plot in, padded by 1 on each side of the data range
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # NOTE(review): the step is the ratio of the bounds, not their span. When
    # x_min is 0 this divides by zero and h becomes inf, so the mesh collapses
    # and contourf rejects Z ("Input z must be at least a 2x2 array").
    # (x_max - x_min)/100 looks like what was intended.
    h = (x_max / x_min)/100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # Predict a class for every mesh node, then restore the grid shape.
    Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    # NOTE(review): labels say "Sepal" regardless of which columns X holds;
    # confirm they match the selected features.
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.xlim(xx.min(), xx.max())
    plt.title(title)
    plt.show()

# Fit a degree-2 polynomial-kernel SVC on the module-level X and y, then plot
# it; plotSVC reads the fitted model through the global name svm_mod.
svm= SVC(C= 10, kernel='poly', degree=2, coef0=1, max_iter=500000)
svm_mod= svm.fit(X,y)
plotSVC('kernel='+ str('polynomial'))

錯誤:

  import sys
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-57-4515f111e34d> in <module>()
      2 svm= SVC(C= 10, kernel='poly', degree=2, coef0=1, max_iter=500000)
      3 svm_mod= svm.fit(X,y)
----> 4 plotSVC('kernel='+ str('polynomial'))

<ipython-input-56-556d4a22026a> in plotSVC(title)
     10     Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
     11     Z = Z.reshape(xx.shape)
---> 12     plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
     13     plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
     14     plt.xlabel('Sepal length')

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in contourf(*args, **kwargs)
   2931                       mplDeprecation)
   2932     try:
-> 2933         ret = ax.contourf(*args, **kwargs)
   2934     finally:
   2935         ax._hold = washold

~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
   1853                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1854                         RuntimeWarning, stacklevel=2)
-> 1855             return func(ax, *args, **kwargs)
   1856 
   1857         inner.__doc__ = _add_data_doc(inner.__doc__,

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in contourf(self, *args, **kwargs)
   6179             self.cla()
   6180         kwargs['filled'] = True
-> 6181         contours = mcontour.QuadContourSet(self, *args, **kwargs)
   6182         self.autoscale_view()
   6183         return contours

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in __init__(self, ax, *args, **kwargs)
    844         self._transform = kwargs.pop('transform', None)
    845 
--> 846         kwargs = self._process_args(*args, **kwargs)
    847         self._process_levels()
    848 

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _process_args(self, *args, **kwargs)
   1414                 self._corner_mask = mpl.rcParams['contour.corner_mask']
   1415 
-> 1416             x, y, z = self._contour_args(args, kwargs)
   1417 
   1418             _mask = ma.getmask(z)

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _contour_args(self, args, kwargs)
   1472             args = args[1:]
   1473         elif Nargs <= 4:
-> 1474             x, y, z = self._check_xyz(args[:3], kwargs)
   1475             args = args[3:]
   1476         else:

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _check_xyz(self, args, kwargs)
   1508             raise TypeError("Input z must be a 2D array.")
   1509         elif z.shape[0] < 2 or z.shape[1] < 2:
-> 1510             raise TypeError("Input z must be at least a 2x2 array.")
   1511         else:
   1512             Ny, Nx = z.shape

TypeError: Input z must be at least a 2x2 array.

工作代碼

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC

# Load the iris dataset once; features and labels are taken from it below.
iris = datasets.load_iris()
# Alternative feature choice (sepal length and width): iris.data[:, :2]
# Features used here: columns 2 onward, i.e. petal length and width.
X, y = iris.data[:, 2:], iris.target

# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title):
    """Draw the fitted SVM's decision regions over the training points.

    Uses the module-level ``X`` (two feature columns), ``y`` (class labels)
    and ``svm_mod`` (the fitted classifier).

    Args:
        title: Title shown above the plot.
    """
    # Pad the data range by 1 on every side so no point sits on the border.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # Step derived from the span of the x range (about 100 columns). The
    # padding above keeps the span at 2 or more, so h is always positive and
    # finite, unlike a ratio of the bounds, which is inf when x_min is 0.
    h = (x_max - x_min) / 100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # Classify every mesh node, then fold the predictions back to grid shape.
    z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    z = z.reshape(xx.shape)
    plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    # X is iris.data[:, 2:], i.e. the petal measurements, not the sepal ones.
    plt.xlabel('Petal length')
    plt.ylabel('Petal width')
    # Clamp both axes to the mesh so the filled regions reach the frame.
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(title)
    plt.show()

# Degree-2 polynomial kernel; plotSVC reads the fitted model via svm_mod.
svm = SVC(C=10, kernel='poly', degree=2, coef0=1, max_iter=500000)
svm_mod = svm.fit(X, y)
plotSVC('kernel=' + 'polynomial')

輸出:

在此處輸入圖像描述


原因:

在 h = (x_max / x_min)/100 這一行中發生了除以零(此時 x_min 為 0),使 h 變成 inf。

需要為h = (x_max - x_min)/100


我通過閱讀聲明的異常發現了這一點

TypeError:輸入 z 必須至少是一個 2x2 數組。

然後反向追溯:z 的形狀來自 xx 的形狀,而 xx 的形狀取決於 h;h 為 inf 顯然不合理,找到這一點之後就很容易解決了。

我認為您應該學習如何更好地使用調試器。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM