Why doesn't fit_transform work in this sklearn Pipeline example?

I am new to sklearn Pipeline and am following some sample code. I saw in other examples that we can call pipeline.fit_transform(train_X) , so I tried the same thing on the pipeline here with pipeline.fit_transform(X) , but it gave me an error:

" return self.fit(X, **fit_params).transform(X)

TypeError: fit() takes exactly 3 arguments (2 given)"

If I remove the svm part and define the pipeline as pipeline = Pipeline([("features", combined_features)]) , I still see the error.

Does anyone know why fit_transform doesn't work here?

from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20

from sklearn.svm import SVC
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest

iris = load_iris()

X, y = iris.data, iris.target

# This dataset is way too high-dimensional. Better do PCA:
pca = PCA(n_components=2)

# Maybe some original features were good, too?
selection = SelectKBest(k=1)

# Build estimator from PCA and Univariate selection:

combined_features = FeatureUnion([("pca", pca), ("univ_select", selection)])

# Use combined features to transform dataset:
X_features = combined_features.fit(X, y).transform(X)

svm = SVC(kernel="linear")

# Do grid search over k, n_components and C:

pipeline = Pipeline([("features", combined_features), ("svm", svm)])

param_grid = dict(features__pca__n_components=[1, 2, 3],
                  features__univ_select__k=[1, 2],
                  svm__C=[0.1, 1, 10])

grid_search = GridSearchCV(pipeline, param_grid=param_grid, verbose=10)
grid_search.fit(X, y)
print(grid_search.best_estimator_)

You get an error in the above example because you also need to pass the labels to your pipeline: call it with both X and y, as in pipeline.fit_transform(X, y) , rather than with X alone. The last step in your pipeline is a classifier, SVC , and the fit method of a classifier takes the labels as a mandatory argument, because classification algorithms use those labels to train the model. Note that fit_transform also needs the final step to provide a transform method, which SVC does not, so with the classifier in place you would call pipeline.fit(X, y) and then pipeline.predict(X) .
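As a rough sketch of the point above, here is the same pipeline from the question fitted with the labels passed in. It assumes a reasonably recent scikit-learn, and it uses fit followed by predict, since the last step is a classifier and not a transformer. The variable name predictions is just for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Same feature union as in the question: PCA plus univariate selection.
combined_features = FeatureUnion([
    ("pca", PCA(n_components=2)),
    ("univ_select", SelectKBest(k=1)),
])

pipeline = Pipeline([
    ("features", combined_features),
    ("svm", SVC(kernel="linear")),
])

# Pass the labels: the pipeline forwards y to every step's fit method,
# and both SelectKBest and SVC need it.
pipeline.fit(X, y)

# The final step is a classifier, so use predict on the fitted pipeline.
predictions = pipeline.predict(X)
print(predictions.shape)
```

The key difference from the failing call is only the extra y argument; the pipeline itself is unchanged.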

Similarly, even if you remove the SVC , you still get an error because the fit method of the SelectKBest class also requires both X and y .
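For the features-only pipeline, a minimal sketch of the working call looks like this. It assumes the same iris data and the same PCA and SelectKBest settings as the question; features_only and X_features are just illustrative names.

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest
from sklearn.pipeline import FeatureUnion, Pipeline

X, y = load_iris(return_X_y=True)

combined_features = FeatureUnion([
    ("pca", PCA(n_components=2)),
    ("univ_select", SelectKBest(k=1)),
])

# Pipeline without the SVC step: every step is a transformer,
# so fit_transform is meaningful here.
features_only = Pipeline([("features", combined_features)])

# SelectKBest scores each feature against the labels, so y must be passed.
X_features = features_only.fit_transform(X, y)

# 2 PCA components + 1 selected feature = 3 columns.
print(X_features.shape)
```

Because every step here is a transformer, fit_transform returns the stacked features directly, as long as y is supplied for SelectKBest.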
