
Bad input shape () in multi-class classification

I'm performing a multi-class classification task using scikit-learn. In the setup I created, I want to compare different classification algorithms.

I use a pipeline where the text is passed in as X and Y is the class (multi-class, N = 5). Textual features are extracted in the pipeline using TfidfVectorizer().

KNN does the job, but the other classifiers raise this error: ValueError: bad input shape (670, 5)

Full traceback:

Traceback (most recent call last):
  File "/Users/Robbert/pipeline.py", line 62, in <module>
    train_pipeline.fit(X_train, Y_train)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/pipeline.py", line 130, in fit
    self.steps[-1][-1].fit(Xt, y, **fit_params)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/svm/base.py", line 138, in fit
    y = self._validate_targets(y)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/svm/base.py", line 441, in _validate_targets
    y_ = column_or_1d(y, warn=True)
  File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/utils/validation.py", line 319, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (670, 5)

The code I use:

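import csv

from sklearn import preprocessing
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split  # sklearn.cross_validation before 0.18
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

def read_data(f):
    # Each ';'-separated row holds the plot text in column 8 and the age class in column 4.
    data = []
    for row in csv.reader(open(f), delimiter=';'):
        if row:
            plottext = row[8]
            target = {'Age': row[4]}
            data.append((plottext, target))
    (X, Ycat) = zip(*data)
    # Y ends up as a one-hot matrix of shape (n_samples, n_classes).
    Y = DictVectorizer().fit_transform(Ycat)
    Y = preprocessing.LabelBinarizer().fit_transform(Y)
    return (X, Y)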

X, Y = read_data('development2.csv')

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33, random_state=42)

###KNN Pipeline
#train_pipeline = Pipeline([
#    ('vect', TfidfVectorizer(ngram_range=(1, 3), min_df=1)),
#    ('clf', KNeighborsClassifier(n_neighbors=350, weights='uniform'))])

###Logistic regression Pipeline
#train_pipeline = Pipeline([
#    ('vect', TfidfVectorizer(ngram_range=(1, 3), min_df=1)),
#    ('clf', LogisticRegression())])

##SVC
train_pipeline = Pipeline([
    ('vect', TfidfVectorizer(ngram_range=(1, 3), min_df=1)),
    ('clf', SVC(C=1, kernel='rbf', gamma=0.001, probability=True))])

##Decision tree
#train_pipeline = Pipeline([
#    ('vect', TfidfVectorizer(ngram_range=(1, 3), min_df=1)),   
#    ('clf', DecisionTreeClassifier(random_state=0))])

train_pipeline.fit(X_train, Y_train)

predicted = train_pipeline.predict(X_test)

print(accuracy_score(Y_test, predicted))

How is it possible that KNN accepts the shape of the array while the other classifiers don't? And how can I change this shape?

If you compare the documentation of the fit(X, y) method in KNeighborsClassifier and SVC, you will see that only the former accepts y of shape [n_samples, n_outputs]; SVC expects a 1-D y of shape [n_samples].
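A minimal sketch of that difference on hypothetical toy data (the array sizes and the label names "a" to "e" are made up for illustration): one-hot targets of shape (n_samples, 5) are accepted by KNeighborsClassifier, which treats each column as a separate output, but rejected by SVC, while the original 1-D labels work for both. The exact ValueError message varies between scikit-learn versions.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelBinarizer
from sklearn.svm import SVC

# Hypothetical toy data: 10 samples, 2 features, 5 string classes.
rng = np.random.RandomState(0)
X = rng.rand(10, 2)
labels = np.array(["a", "b", "c", "d", "e"] * 2)

# One-hot encoding the labels turns y into a 2-D array.
Y_onehot = LabelBinarizer().fit_transform(labels)
print(Y_onehot.shape)  # (10, 5)

# KNeighborsClassifier fits one output per column, so the 2-D target is accepted.
knn = KNeighborsClassifier(n_neighbors=3).fit(X, Y_onehot)
print(knn.predict(X).shape)  # (10, 5)

# SVC only accepts a 1-D y, so the same target raises ValueError.
svc_rejected = False
try:
    SVC().fit(X, Y_onehot)
except ValueError as exc:
    svc_rejected = True
    print("SVC rejected y:", exc)

# The original 1-D string labels work for SVC as well.
svc = SVC().fit(X, labels)
print(svc.predict(X[:2]))
```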

Possible solution: why do you need LabelBinarizer at all? Just don't use it.
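Note that DictVectorizer also one-hot encodes string values, so in the question's read_data it already produces a 2-D Y on its own. A minimal sketch that skips both encoders and keeps the labels as a flat list of strings, assuming the same layout as the question (';'-separated rows, class in column 4, text in column 8); the in-memory sample rows and their labels are hypothetical stand-ins for development2.csv. probability=True is dropped here only because its internal cross-validation is not meaningful on six rows.

```python
import csv
import io

from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC


def read_data(lines):
    """Return (texts, labels), keeping labels as a 1-D list of strings."""
    texts, labels = [], []
    for row in csv.reader(lines, delimiter=';'):
        if row:
            texts.append(row[8])   # plot text
            labels.append(row[4])  # class label, left unencoded
    return texts, labels


# Hypothetical sample rows standing in for development2.csv.
SAMPLE = io.StringIO(
    ";;;;child;;;;a story about a small dragon\n"
    ";;;;teen;;;;a story about school and exams\n"
    ";;;;adult;;;;a story about work and taxes\n"
    ";;;;child;;;;a small dragon learns to fly\n"
    ";;;;teen;;;;exams and a school dance\n"
    ";;;;adult;;;;taxes and a long commute to work\n"
)
X, Y = read_data(SAMPLE)

# For comparison: DictVectorizer alone already one-hot encodes the labels.
onehot = DictVectorizer().fit_transform([{'Age': label} for label in Y])
print(onehot.shape)  # (6, 3)

pipeline = Pipeline([
    ('vect', TfidfVectorizer(ngram_range=(1, 3), min_df=1)),
    ('clf', SVC(C=1, kernel='rbf', gamma=0.001)),
])
pipeline.fit(X, Y)  # Y is 1-D, so SVC accepts it
print(pipeline.predict(["a dragon at school"]))
```

If integer labels are preferred, sklearn.preprocessing.LabelEncoder also keeps them 1-D.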

If your Y array has shape (n_samples, n_classes) and contains at least one row with more than one non-zero element, then you are solving a multi-label classification problem. If that is the case, the "Multiclass and multilabel algorithms" page in the scikit-learn docs lists KNN as one of the classifiers that support multi-label classification. You might want to try out other classifiers from that list:

* sklearn.tree.DecisionTreeClassifier
* sklearn.tree.ExtraTreeClassifier
* sklearn.ensemble.ExtraTreesClassifier
* sklearn.neural_network.MLPClassifier
* sklearn.neighbors.RadiusNeighborsClassifier
* sklearn.ensemble.RandomForestClassifier
* sklearn.linear_model.RidgeClassifierCV
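The multi-label test above (some row with more than one non-zero element) can be checked directly. A sketch on hypothetical toy arrays; is_multilabel is a helper name introduced here for illustration. If no row has more than one 1, the target is really multi-class and argmax(axis=1) collapses it back to the 1-D shape SVC and LogisticRegression expect; otherwise keep it 2-D and use an estimator from the list above, here DecisionTreeClassifier.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: 4 samples, 1 feature, 3 target columns.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
Y_onehot = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])  # one 1 per row
Y_multi = np.array([[1, 0, 0], [0, 1, 1], [0, 0, 1], [1, 0, 0]])   # row 2 has two 1s


def is_multilabel(Y):
    """True if at least one row has more than one non-zero element."""
    return bool((np.count_nonzero(Y, axis=1) > 1).any())


# Multi-class case: turn the one-hot matrix back into 1-D class indices.
if not is_multilabel(Y_onehot):
    y_1d = Y_onehot.argmax(axis=1)
    print(y_1d)  # [0 1 2 0]

# Multi-label case: keep the 2-D target and use an estimator that supports it.
if is_multilabel(Y_multi):
    tree = DecisionTreeClassifier(random_state=0).fit(X, Y_multi)
    print(tree.predict(X).shape)  # (4, 3)
```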
