How do I incorporate SelectKBest in an sklearn pipeline

I'm trying to build a text classifier with sklearn. The idea is to:

  1. Vectorize the training corpus using TfidfVectorizer
  2. Select the top 20,000 resulting features using SelectKBest (or use all features if fewer than 20,000 are produced)
  3. Feed these features into a Logistic Regression classifier

I've set it up successfully as follows:我已经成功设置如下:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression

vectorizer = TfidfVectorizer()
x_train = vectorizer.fit_transform(df_train["input"])
selector = SelectKBest(f_classif, k=min(20000, x_train.shape[1]))
selector.fit(x_train, df_train["label"].values)
x_train = selector.transform(x_train)
classifier = LogisticRegression()
classifier.fit(x_train, df_train["label"])
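To make the capping concrete, here is the manual approach above run end to end on a tiny corpus. The four documents and labels are made up and stand in for df_train; everything else uses the same sklearn classes as above. The vectorizer produces only six features here, so min() caps k at 6 instead of 20,000.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression

# Hypothetical toy corpus standing in for df_train["input"] / df_train["label"]
texts = ["good movie", "bad movie", "great film", "awful film"]
labels = [1, 0, 1, 0]

vectorizer = TfidfVectorizer()
x_train = vectorizer.fit_transform(texts)
n_features = x_train.shape[1]  # 6 distinct tokens: good, movie, bad, great, film, awful

# k can only be computed after the vectorizer has seen the data
k = min(20000, n_features)
selector = SelectKBest(f_classif, k=k)
x_selected = selector.fit_transform(x_train, labels)

classifier = LogisticRegression()
classifier.fit(x_selected, labels)
print(n_features, k, x_selected.shape)  # 6 6 (4, 6)
```

This is exactly the step a plain Pipeline cannot express: k has to be known when SelectKBest is constructed, before any data has been vectorized.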

I would now like to wrap all of this into a pipeline and share it so that others can use it on their own text data. However, I can't figure out how to get SelectKBest to behave as it does above, i.e. to use min(20000, number of features produced by the vectorizer) as k. If I simply leave it as k=20000, as below, the pipeline throws an error when fitting a new corpus with fewer than 20k vectorized features.

from sklearn.pipeline import Pipeline

pipe = Pipeline([
            ("vect",TfidfVectorizer()),
            ("selector",SelectKBest(f_classif, k=20000)),
            ("clf",LogisticRegression())])

As @vivek kumar pointed out, you need to override the _check_params method of SelectKBest and add your logic to it, as shown below:

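import warnings

from sklearn.feature_selection import SelectKBest


class MySelectKBest(SelectKBest):
    def _check_params(self, X, y):
        # Cap k at the number of available features instead of raising.
        # Note: this overwrites self.k, so the fitted selector keeps the capped value.
        if self.k != "all" and self.k > X.shape[1]:
            warnings.warn("Less than %d number of features found, so setting k as %d" % (self.k, X.shape[1]),
                          UserWarning)
            self.k = X.shape[1]
        if not (self.k == "all" or 0 <= self.k):
            raise ValueError("k should be >=0, <= n_features = %d; got %r. "
                             "Use k='all' to return all features."
                             % (X.shape[1], self.k))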

I have also added a warning for the case where the number of features found is less than the requested k. Now let's look at a working example:

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
import warnings

categories = ['alt.atheism', 'comp.graphics',
              'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware',
              'comp.windows.x', 'misc.forsale', 'rec.autos']
newsgroups = fetch_20newsgroups(categories=categories)
y_true = newsgroups.target

# These newsgroups yield roughly 47K features after TF-IDF vectorization

# Case 1: k < number of features (the regular case)
pipe = Pipeline([
            ("vect",TfidfVectorizer()),
            ("selector",MySelectKBest(f_classif, k=30000)),
            ("clf",LogisticRegression())])

pipe.fit(newsgroups.data, y_true)
pipe.score(newsgroups.data, y_true)
# 0.968 (accuracy on the training set)

# Case 2: k > number of features (the problematic case)

pipe = Pipeline([
            ("vect",TfidfVectorizer()),
            ("selector",MySelectKBest(f_classif, k=50000)),
            ("clf",LogisticRegression())])

pipe.fit(newsgroups.data, y_true)
# UserWarning: Less than 50000 number of features found, so setting k as 47407

pipe.score(newsgroups.data, y_true)
# 0.9792 (accuracy on the training set)
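One caveat with this override: it assigns to self.k inside fit, so once MySelectKBest has been fitted on a corpus with fewer features, the reduced k sticks, and refitting the same pipeline on a larger corpus keeps only that many features. It also relies on _check_params, which is a private method rather than public scikit-learn API. If either matters, here is a minimal alternative sketch that is not part of the original answer. CappedSelectKBest is a hypothetical name, and the toy documents are made up. It wraps a SelectKBest that is built at fit time, using only public classes.

```python
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


class CappedSelectKBest(TransformerMixin, BaseEstimator):
    """Keep the k best features, or all of them if fewer than k exist.

    Hypothetical helper: the user-supplied k is never modified; the
    effective value lives on the fitted inner selector (selector_).
    """

    def __init__(self, score_func=f_classif, k=20000):
        self.score_func = score_func
        self.k = k

    def fit(self, X, y):
        # Decide the effective k only now, when the number of features is known
        effective_k = self.k if self.k == "all" else min(self.k, X.shape[1])
        self.selector_ = SelectKBest(self.score_func, k=effective_k).fit(X, y)
        return self

    def transform(self, X):
        return self.selector_.transform(X)


pipe = Pipeline([
    ("vect", TfidfVectorizer()),
    ("selector", CappedSelectKBest(f_classif, k=20000)),
    ("clf", LogisticRegression())])

# Hypothetical toy corpus with far fewer than 20,000 features
docs = ["good movie", "bad movie", "great film", "awful film"]
labels = [1, 0, 1, 0]
pipe.fit(docs, labels)

selector = pipe.named_steps["selector"]
print(selector.k, selector.selector_.k)  # 20000 6
```

The design choice here is that the constructor argument k stays exactly as passed, which follows the scikit-learn convention that fit should not change hyperparameters; the capped value is stored on the fitted attribute selector_ instead, so clone() and repeated fits behave predictably.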

Hope this helps!

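Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish this content, please cite this site's URL or the original address. For any questions, contact yoyou2525@163.com.

Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM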