简体   繁体   English

是否可以使用`sklearn`重新训练保存的神经网络

[英]Is it possible to retrain a saved Neural Network using `sklearn`

I'm working on a project that looks to classify tweets and am using sklearn 's neural network model . 我正在开发一个项目,该项目旨在对推文进行分类,并且正在使用sklearn神经网络模型 Is it possible to retrain this using sklearn and if so please guide me in the right direction. 是否可以使用sklearn对其进行重新培训,如果可以,请以正确的方向指导我。 Also, is it worth it to retrain a model or should I just adjust values when constructing a network. 另外,重新训练模型是否值得,或者在构建网络时我应该只是调整值。

You may try the following. 您可以尝试以下方法。

from sklearn.externals import joblib
##Suppose your trained model is named MyTrainedModel
##This is how you save it in a file called MyTrainedModelFile.txt.
joblib.dump(MyTrainedModel, 'MyTrainedModelFile.txt')
##Later you can recall the model and use it
Loaded_model = joblib.load('MyTrainedModelFile.txt')

The tutorial is here . 本教程在这里

Please let me know if this is what you wanted. 请让我知道这是否是您想要的。

You could very well do that using the partial_fit method that MLPClasifier offers. 您可以使用MLPClasifier提供的partial_fit方法很好地做到这MLPClasifier I have written a sample code for doing it. 我已经为此编写了示例代码。 You could very well retrain your saved model if you get data in batches and training is a costy operation for you so can't afford to train on the entire dataset each and every time you get a new batch of data. 如果您批量获取数据,那么很好地重新训练您保存的模型,而训练对您来说是一项昂贵的操作,因此每次获得新一批数据时都无法承受对整个数据集的训练。

import pickle
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_classes=4, n_features=11,
                       n_informative=4, weights=[0.25,0.25,0.25,0.25],
                       random_state=0)

x_batch1 = X[0:500]
y_batch1 = y[0:500]

x_batch2 = X[500:999]
y_batch2 = y[500:999]

clf = MLPClassifier()
clf.partial_fit(x_batch1, y_batch1, classes = np.unique(y))  # you need to pass the classes when you fit for the first time


pickle.dump(clf, open("MLP_classifier", 'wb'))
restored_clf = pickle.load(open("MLP_classifier", 'rb'))

restored_clf.partial_fit(x_batch2, y_batch2)

Hope this helps! 希望这可以帮助!

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM