
How to save a TF-IDF vectorizer in scikit-learn?

I am developing a spam classifier using scikit-learn.

Here is my vectorizing code:

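import pickle

from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer(
    analyzer='word',
    sublinear_tf=True,
    strip_accents='unicode',
    token_pattern=r'\w{1,}',
    ngram_range=(1, 1),
    max_features=10000)

# Learn the vocabulary and IDF weights, then vectorize the training text
tfidf = vectorizer.fit(data['text'])
features = vectorizer.transform(data['text'])

# Save the fitted vectorizer; the with-block closes the file afterwards
with open('tfidf.pickle', 'wb') as f:
    pickle.dump(tfidf, f)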

Here is what I am doing to predict on new input:

import pickle

import joblib

# Load the trained classifier and the fitted vectorizer
model = joblib.load('model')

with open('tfidf.pickle', 'rb') as f:
    vect = pickle.load(f)

new = vect.transform(['some new text...'])

model.predict(new)

When I load the saved vectorizer (tfidf.pickle) and try to predict on a new message, I get this error:

ValueError: X.shape[1] = 7148 should be equal to 38011, the number of features at training time

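The error message says that your model expects input with 38011 features, while your TF-IDF vectorizer outputs vectors of dimension 7148. You have a model/preprocessor mismatch: the saved model and the saved vectorizer do not come from the same fit. Note that with max_features=10000 the vectorizer shown in the question can never produce 38011 features, so the model must have been trained on the output of a differently configured or differently fitted vectorizer.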
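A quick way to confirm this kind of mismatch before calling predict is to compare the vectorizer's vocabulary size with the number of features the model saw during fitting. The snippet below is only a sketch: the toy data and the helper name check_compatible are made up for illustration, and it assumes a reasonably recent scikit-learn in which fitted estimators expose n_features_in_.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

# Toy data, made up for illustration only
texts = ["win a free prize now", "meeting at noon tomorrow",
         "free cash prize", "lunch tomorrow?"]
labels = [1, 0, 1, 0]

# A vectorizer and a model that were fitted together
vectorizer = TfidfVectorizer().fit(texts)
model = LogisticRegression().fit(vectorizer.transform(texts), labels)


def check_compatible(vectorizer, model):
    """Hypothetical helper: compare vectorizer output width with model input width."""
    n_vec = len(vectorizer.vocabulary_)   # columns the vectorizer produces
    n_model = model.n_features_in_        # columns the model was fitted on
    return n_vec == n_model, n_vec, n_model


ok, n_vec, n_model = check_compatible(vectorizer, model)
print(ok, n_vec, n_model)

# A vectorizer fitted on other text has a different vocabulary size,
# which reproduces the situation behind the ValueError
stale_vectorizer = TfidfVectorizer().fit(["completely different words"])
print(check_compatible(stale_vectorizer, model))
```

Equal counts do not prove that the two objects belong together (two different vocabularies can have the same size), so treat this as a sanity check rather than a guarantee.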

A good way to avoid this preprocessing/model mismatch is to use a scikit-learn pipeline. For instance, you could train your TF-IDF vectorizer and your model together with the following piece of code (the example uses a logistic regression):

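from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline  # lives in sklearn.pipeline, not sklearn.preprocessing

vectorizer = TfidfVectorizer()  # pass your TF-IDF arguments here
model = LogisticRegression()  # pass your model arguments here
pipeline = make_pipeline(vectorizer, model)

# X is the raw text (e.g. data['text']), not the TF-IDF features; y holds the labels
pipeline.fit(X, y)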

You can then serialize and load the whole pipeline with pickle or joblib (e.g. pickle.dump(pipeline, open('spam_pipeline.pickle', 'wb')), then pipeline = pickle.load(open('spam_pipeline.pickle', 'rb'))), similarly to what you were already doing.

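You can call the predict method of the pipeline directly on raw text to get a prediction. The pipeline applies the fitted vectorizer before the model, so the two can no longer get out of sync.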
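Putting the steps together, a minimal end-to-end sketch could look like the following. The toy texts, the labels, and the temporary file location are assumptions for illustration only; in your case you would fit on data['text'] and your own labels, with your own TF-IDF and model arguments.

```python
import os
import pickle
import tempfile

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Toy data, made up for illustration only (1 = spam, 0 = not spam)
texts = ["win a free prize now", "meeting at noon tomorrow",
         "free cash prize", "lunch tomorrow?"]
labels = [1, 0, 1, 0]

# Vectorizer and classifier are fitted together on the raw text
pipeline = make_pipeline(TfidfVectorizer(sublinear_tf=True), LogisticRegression())
pipeline.fit(texts, labels)

# Save the whole pipeline as a single file (temporary directory used here)
path = os.path.join(tempfile.mkdtemp(), "spam_pipeline.pickle")
with open(path, "wb") as f:
    pickle.dump(pipeline, f)

# Later, possibly in another process: load it and predict on raw text
with open(path, "rb") as f:
    loaded = pickle.load(f)

prediction = loaded.predict(["claim your free prize"])
print(prediction)
```

Because only one object is saved and loaded, there is no separate vectorizer file that could come from a different training run. As with any pickle, only load files you trust, and load them with the same scikit-learn version that wrote them where possible.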

Let me know if you need more details.

