After training a Naive Bayes text classification algorithm, how do I predict the topic of a single text file?
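I have trained and tested a Naive Bayes algorithm using training and test data. Now I want to predict the topic of a single text file.
Here is my code: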
#importing test, train data
import sklearn.datasets as skd
categories = ['business', 'entertainment','local', 'sports', 'world']
# Raw strings keep the backslash in the Windows path from being read as an escape sequence
sinhala_train = skd.load_files(r'Cleant data\stemmed_filtered_sinhala-set1', categories=categories, encoding='utf-8')
sinhala_test = skd.load_files(r'Cleant data\stemmed_filtered_sinhala-set2', categories=categories, encoding='utf-8')
name_file = "adaderana_67571.txt"
# Read the single file to classify; the with-block closes the file afterwards
with open(name_file, encoding='utf-8') as A:
    new_file = A.read()
from sklearn.feature_extraction.text import CountVectorizer
count_vectorization = CountVectorizer()
train_data_tf = count_vectorization.fit_transform(sinhala_train.data)
train_data_tf.shape
from sklearn.feature_extraction.text import TfidfTransformer
tfidf_trans = TfidfTransformer()
train_data_tfidf = tfidf_trans.fit_transform(train_data_tf)
train_data_tfidf.shape
from sklearn.naive_bayes import MultinomialNB
clf = MultinomialNB().fit(train_data_tfidf, sinhala_train.target)
test_data_tf = count_vectorization.transform(sinhala_test.data)
# Reuse the IDF weights learned from the training set; do not refit on the test data
test_data_tfidf = tfidf_trans.transform(test_data_tf)
predicted = clf.predict(test_data_tfidf)
from sklearn import metrics
from sklearn.metrics import accuracy_score
print("Accuracy of the model:", accuracy_score(sinhala_test.target, predicted))
print(metrics.classification_report(sinhala_test.target, predicted, target_names=sinhala_test.target_names))
metrics.confusion_matrix(sinhala_test.target, predicted)
Here is my output:
Accuracy of the model: 0.864
               precision    recall  f1-score   support

     business       0.78      0.94      0.85       100
entertainment       0.95      0.86      0.90       100
        local       0.89      0.65      0.75       100
       sports       0.91      0.93      0.92       100
        world       0.83      0.94      0.88       100

    micro avg       0.86      0.86      0.86       500
    macro avg       0.87      0.86      0.86       500
 weighted avg       0.87      0.86      0.86       500
array([[94, 2, 4, 0, 0],
[ 2, 86, 2, 4, 6],
[19, 0, 65, 5, 11],
[ 1, 3, 1, 93, 2],
[ 5, 0, 1, 0, 94]], dtype=int64)
Now I want to predict the topic of the text file new_file.
Can someone help me write the code to predict the topic of this text file?
I solved my problem. Here is the code I used to predict the topic.
def predict_topic(text):
    # The vectorizer expects an iterable of documents, so wrap the single text in a list
    docs_new = [text]
    X_new_counts = count_vectorization.transform(docs_new)
    X_new_tfidf = tfidf_trans.transform(X_new_counts)
    predicted_topic = clf.predict(X_new_tfidf)
    # Map the predicted class index back to its category name
    return sinhala_train.target_names[predicted_topic[0]]

print(predict_topic(new_file))
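The same vectorize, weight, and classify steps can also be chained in a scikit-learn Pipeline, so a new document goes through exactly the transformations that were fitted on the training data. The following is only a minimal sketch, not the code from the post: the helper names build_topic_pipeline and predict_with_pipeline are hypothetical, and the four-sentence toy corpus is made up purely so the example runs without the Sinhala dataset.

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline


def build_topic_pipeline(train_texts, train_targets):
    """Fit counts -> TF-IDF -> Multinomial Naive Bayes as one estimator (hypothetical helper)."""
    model = Pipeline([
        ("counts", CountVectorizer()),
        ("tfidf", TfidfTransformer()),
        ("nb", MultinomialNB()),
    ])
    model.fit(train_texts, train_targets)
    return model


def predict_with_pipeline(model, text, target_names):
    """Return the category name predicted for one raw text string (hypothetical helper)."""
    # predict() expects an iterable of documents, so wrap the single text in a list
    label = model.predict([text])[0]
    return target_names[label]


# Toy data, invented for illustration only
toy_texts = [
    "team wins the match",
    "player scores a goal",
    "market shares rise",
    "bank profits grow",
]
toy_targets = [0, 0, 1, 1]
toy_names = ["sports", "business"]

toy_model = build_topic_pipeline(toy_texts, toy_targets)
print(predict_with_pipeline(toy_model, "the team scores a goal", toy_names))
```

With the data from the question, this would presumably become model = build_topic_pipeline(sinhala_train.data, sinhala_train.target) followed by predict_with_pipeline(model, new_file, sinhala_train.target_names).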