For a machine learning project, I'm trying to predict a categorical outcome variable using features extracted from text.
Using a hold-out split (train_test_split), I split my X and Y into a training set and a test set. The training set is used to fit a pipeline. However, when I evaluate the pipeline on X_test, the accuracy is 0.0. Note that no features have been extracted from X_test yet.
Is it possible to split the dataset within the pipeline?
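A scikit-learn Pipeline's predict() runs its input through every fitted transformer before it reaches the final estimator, so raw test text does not need features extracted beforehand. The minimal sketch below illustrates this on a tiny hypothetical corpus (`docs` and `labels` are made up). It also shows cross_val_score, which does the splitting itself and refits the whole pipeline, vectorizer included, on each training fold.

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

# Hypothetical toy corpus with 1-D class labels (illustration only)
docs = [
    "cheap pills buy now", "meeting at noon tomorrow", "buy cheap watches now",
    "lunch meeting moved", "win cash now cheap", "project meeting notes",
]
labels = [1, 0, 1, 0, 1, 0]

pipe = Pipeline([
    ("vect", CountVectorizer()),
    ("tfidf", TfidfTransformer(use_idf=False)),
    ("clf", SVC(kernel="linear")),
])

# fit() learns the vocabulary from the training documents only
pipe.fit(docs[:4], labels[:4])

# predict() takes RAW text: the fitted vectorizer and tf transformer
# convert it to features before the classifier sees it
preds = pipe.predict(["buy cheap pills", "team meeting today"])

# cross_val_score splits the data itself and refits a clone of the
# whole pipeline on each training fold (3 folds here)
scores = cross_val_score(pipe, docs, labels, cv=3)
print(preds, scores)
```

Because the vectorizer lives inside the pipeline, each fold's vocabulary comes only from that fold's training documents, which avoids leaking test-set words into training.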
My code:
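```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

X, Y = read_data('development2.csv')  # read_data: custom loader, not shown
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33, random_state=42)
train_pipeline = Pipeline([
    ('vect', CountVectorizer()),  # ngram_range=(1,2), analyzer='word'
    ('tfidf', TfidfTransformer(use_idf=False)),
    ('clf', OneVsRestClassifier(SVC(kernel='linear', probability=True))),
])
train_pipeline.fit(X_train, Y_train)
predicted = train_pipeline.predict(X_test)
print(accuracy_score(Y_test, predicted))
```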
The traceback when using a plain SVC as the final step (without OneVsRestClassifier):
```
  File "/Users/Robbert/Documents/pipeline.py", line 62, in <module>
    train_pipeline.fit(X_train, Y_train)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/pipeline.py", line 130, in fit
    self.steps[-1][-1].fit(Xt, y, **fit_params)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/svm/base.py", line 138, in fit
    y = self._validate_targets(y)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/svm/base.py", line 441, in _validate_targets
    y_ = column_or_1d(y, warn=True)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/utils/validation.py", line 319, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (670, 5)
```
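The traceback ends in column_or_1d, which rejects any target that is not 1-D. The minimal sketch below reproduces this with hypothetical numeric features and a one-hot target shaped like the (670, 5) array in the error (here 6 samples by 5 classes). The exact error message differs between scikit-learn versions, so only the exception type is relied on.

```python
import numpy as np
from sklearn.svm import SVC

# Hypothetical features: 6 samples, 2 numeric columns
X = np.arange(12, dtype=float).reshape(6, 2)

# One-hot target: one row per sample, one column per class (5 classes)
Y_onehot = np.eye(5, dtype=int)[[4, 2, 0, 1, 3, 4]]

error = None
try:
    SVC(kernel="linear").fit(X, Y_onehot)
except ValueError as exc:
    # SVC validates y as a 1-D array; a 2-D (n_samples, n_classes)
    # array is rejected before any training happens
    error = exc

print(Y_onehot.shape, repr(error))
```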
I solved the problem.
The target variable (Y) was not in the format SVC expects. Each label was stored as a one-hot row, like this: [[0 0 0 0 1], [0 0 1 0 0]]. I converted it into a 1-D array of class labels, like this: [5, 3]. This did the trick for me.
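One way to do that conversion, sketched under the assumption that every row holds exactly one 1, is numpy.argmax along axis 1, plus 1 to reproduce the 1-based labels in the [5, 3] example. The second half fits a pipeline like the one in the question on hypothetical text data (`docs`, `docs_onehot` are made up) to show that SVC accepts the converted labels.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

# One-hot rows in the format from the post
Y_onehot = np.array([[0, 0, 0, 0, 1],
                     [0, 0, 1, 0, 0]])

# argmax returns 0 for an all-zero row without complaint, so check first
if not (Y_onehot.sum(axis=1) == 1).all():
    raise ValueError("each row must contain exactly one 1")

# Column index of the 1 in each row, shifted to 1-based class labels
Y = np.argmax(Y_onehot, axis=1) + 1
print(Y)  # [5 3]

# Hypothetical text data with one-hot targets, converted the same way
docs = ["cheap pills buy now", "meeting at noon",
        "buy cheap watches", "lunch meeting moved"]
docs_onehot = np.array([[0, 1], [1, 0], [0, 1], [1, 0]])

pipe = Pipeline([
    ("vect", CountVectorizer()),
    ("tfidf", TfidfTransformer(use_idf=False)),
    ("clf", SVC(kernel="linear")),
])
# A 1-D label vector is what SVC expects, so fit() no longer raises
pipe.fit(docs, docs_onehot.argmax(axis=1) + 1)
predicted = pipe.predict(["cheap watches now"])
```

The +1 only matters if you want labels to match the 1-based numbering; 0-based labels from argmax alone work just as well. Passing the one-hot Y to OneVsRestClassifier instead makes scikit-learn treat the task as multilabel classification, which is a different problem from picking a single category.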
Thanks for all the answers.