簡體   English   中英

__init __()得到了意外的關鍵字參數'n_splits'錯誤

[英]__init__() got an unexpected keyword argument 'n_splits' ERROR

我打算嘗試此鏈接中的代碼:

我從引用StratifiedKFold(n_splits=60)的行中得到錯誤。 誰能告訴我如何解決這個錯誤?

這是代碼:

import numpy as np
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle

from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.cross_validation import StratifiedKFold

iris = datasets.load_iris()
X = iris.data
y = iris.target
X, y = X[y != 2], y
X, y

cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel='linear', probability=True,
                     random_state=random_state)

mean_tpr = 0.0
mean_fpr = np.linspace(0, 1, 100)

這是錯誤:

TypeError                                 Traceback (most recent call last)
<ipython-input-227-2af2773f4987> in <module>()
----> 1 sklearn.cross_validation.StratifiedKFold(n_splits=6)
      2 #cv = StratifiedKFold(n_splits=6,  shuffle=True, random_state=1)
      3 classifier = svm.SVC(kernel='linear', probability=True,
      4                      random_state=random_state)
      5 

TypeError: __init__() got an unexpected keyword argument 'n_splits'

導入sklearn.cross-validation模塊時,您沒有收到任何警告。 這意味着您安裝的版本小於0.18。

如果您的scikit-learn版本< 0.18 ,則更改以下幾行:(摘自StratifiedKFold文檔中的0.17版

#Notice the extra parameter y and change of name for n_splits to n_folds
cv = StratifiedKFold(y, n_folds=6)

#Also note that the cv is called directly in for loop
for train_index, test_index in cv:
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]

如果您的scikit-learn版本>=0.18 ,則只有您可以對cv使用n_splits參數:(摘自StratifiedKFold當前文檔 ,這是我認為的意思)

#Notice the extra parameter y is removed here
cv = StratifiedKFold(n_splits=6)

#Also note that the cv.split() is called here (opposed to cv in ver 0.17 above)
for train_index, test_index in cv.split(X, y):
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]

建議

將您的scikit-learn更新到最新版本0.18。 因為您可以通過直接搜索找到的大多數文檔都是此版本,這會讓您感到困惑。

編輯:

我已經在這里回答了您的類似問題:- 交叉驗證問題

因此,下一次,請提及您在問題本身中使用的庫的版本,並記住訪問它們的相關文檔,而不是其他文檔。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM