How to get the predicted label using sklearn and numpy?

I am trying to use sklearn to classify texts, training on a folder in which each subfolder is a collection of txt files:

import numpy
from sklearn.datasets import load_files
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

path = 'C:\\wamp64\\www\\machine_learning\\webroot\\mini_iniciais\\'

# loading: each subfolder name becomes a class in data.target_names
data = load_files(path, encoding="utf-8", decode_error="replace")
labels, counts = numpy.unique(data.target, return_counts=True)
labels_str = numpy.array(data.target_names)[labels]
print(dict(zip(labels_str, counts)))

# building: split the data, vectorize the training texts, train the classifier
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target)
vectorizer = TfidfVectorizer(max_features=1000, decode_error="ignore")
vectorizer.fit(X_train)
X_train_vectorized = vectorizer.transform(X_train)

cls = MultinomialNB()
cls.fit(X_train_vectorized, y_train)

texts_to_predict = ["medicamento"]

result = cls.predict(vectorizer.transform(texts_to_predict))
print(result)

This is the output of print(dict(zip(labels_str, counts))):

{'PG16-PROCURADORIA-DE-SERVICOS-DE-SAUDE': 10, 'PP-PROCURADORIA-DE-PESSOAL-PG04': 10, 'PPMA-PROCURADORIA-DE-PATRIMONIO-E-MEIO-AMBIENTE-PG06': 10, 'PPREV-PROCURADORIA-PREVIDENCIARIA-PG07': 10, 'PSP-PROCURADORIA-DE-SERVICOS-PUBLICOS-PG08': 10, 'PTRIB-PROCURADORIA-TRIBUTARIA-PG03': 10}

But the result of cls.predict is only an array containing an integer:

[0]

Or [1], [3], etc. when I change the value of texts_to_predict.

So, how can I get the name of one of the subfolders as the result of the prediction?

According to the documentation of load_files, the target_names attribute of the returned data holds

[t]he names of target classes.

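The integers in data.target, and therefore the integers returned by cls.predict, are indices into that list. So, consider using something like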

print([data.target_names[x] for x in result])

instead of

print(result)
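The lookup can be tried on its own, without the dataset. The snippet below is a minimal sketch: the helper indices_to_names is a hypothetical name, and the shortened target_names list and the predicted values are stand-ins for data.target_names and the output of cls.predict, not real results.

```python
from typing import Iterable, List, Sequence


def indices_to_names(target_names: Sequence[str], predicted: Iterable[int]) -> List[str]:
    """Map integer class indices (as returned by predict) to class names."""
    # int(i) also accepts numpy integer types such as numpy.int64
    return [target_names[int(i)] for i in predicted]


# Hypothetical stand-in for data.target_names (subfolder names, sorted)
target_names = [
    "PG16-PROCURADORIA-DE-SERVICOS-DE-SAUDE",
    "PP-PROCURADORIA-DE-PESSOAL-PG04",
    "PTRIB-PROCURADORIA-TRIBUTARIA-PG03",
]

# Hypothetical stand-in for cls.predict(...) on two texts
predicted = [0, 2]

names = indices_to_names(target_names, predicted)
print(names)
```

With a numpy array, the same lookup can also be written as numpy.array(data.target_names)[result], which is the indexing the question already uses to build labels_str.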
