簡體   English   中英

在 Python 中保存經過訓練的多輸入分類算法

[英]Saving a trained multi-input classification algorithm in Python

我開發了一個腳本,根據之前手動標記的反饋預測某些文本的可能標記。 我使用了幾篇在線文章來幫助我(即: https://towardsdatascience.com/multi-label-text-classification-with-scikit-learn-30714b7819c5 )。

因為我想要每個標簽的概率,所以這是我使用的代碼:

NB_pipeline = Pipeline([
    ('clf', OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
    ])

predictions_en = {}
for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

preds_en = pd.DataFrame(predictions_en.items())
preds_en = preds_en.sort_values(by=[1], ascending=False)
preds_en = preds_en.reset_index(drop=True)

它非常適合我的目的:它為每個可能的標簽返回一個預測。 但我的問題是,每次我嘗試進行預測時,它都會重新訓練算法。 我想做的是在腳本中訓練算法,保存訓練后的算法,將其加載到另一個進行預測的腳本中。

我希望能夠在腳本 1 中執行此操作:

for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])

這在另一個腳本中:

for category in categories_en:
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

但我似乎無法讓它工作。 當我嘗試將它分開時,它只是給了我相同的預測。

您始終可以使用pickle序列化任何 python object 包括您的。 因此,保存 model 的最簡單和最快的方法是將其序列化到一個文件中,例如model.pickle 這是在訓練 model 后的第一部分完成的。 之后,您所要做的就是檢查文件是否存在並再次使用pickle對其進行反序列化。

這是一個 function 序列化 python 對象到文件:

import pickle

def serialize(obj, file):
    with open(file, 'wb') as f:
        pickle.dump(obj, f)

這是一個 function 反序列化文件中的 python 對象:

import pickle

def deserialize(file):
    with open(file, 'rb') as f:
        return pickle.load(f)

完成訓練后,您只需調用(如果NB_pipeline是您的模型的 object):

serialize(NB_pipeline, 'model.pickle')

當你必須加載它並使用它時,只需調用:

NB_pipeline = deserialize('model.pickle')

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM