[英]how to store/save and restore tensorflow DNNClassifier(No variables to save)
[英]How to save&restore DNNClassifier trained in TensorFlow python; iris example
我是TensorFlow的新手,幾天前才開始學習。 我已完成本教程( https://www.tensorflow.org/versions/r0.9/tutorials/tflearn/index.html#tf-contrib-learn-quickstart )並將完全相同的想法應用於我自己的數據集。 (出來很不錯!)
現在,我想保存並恢復經過培訓的DNNClassifier以供進一步使用。 如果有人知道如何操作,請使用上面鏈接中的iris示例代碼告訴我。 感謝您的幫助!
找到了解決方案嗎? 如果沒有,您可以在創建DNNClassifier時在構造函數上指定model_dir參數,這將創建此目錄中的所有檢查點和文件(保存步驟)。 如果要執行還原步驟,只需創建另一個傳遞相同model_dir參數(還原階段)的DNNClassifier,這將從第一次創建的文件中還原模型。
希望這對你有所幫助。
以下是我的代碼......
import tensorflow as tf
import numpy as np
if __name__ == '__main__':
# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING, target_dtype=np.int)
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST, target_dtype=np.int)
x_train, x_test, y_train, y_test = training_set.data, test_set.data, training_set.target, test_set.target
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, model_dir="path_to_my_local_dir")
# print classifier.model_dir
# Fit model.
print "start fitting model..."
classifier.fit(x=x_train, y=y_train, steps=200)
print "finished fitting model!!!"
# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=x_test, y=y_test)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
#Classify two new flower samples.
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict_proba(new_samples)
print ('Predictions: {}'.format(str(y)))
#---------------------------------------------------------------------------------
#model_dir below has to be the same as the previously specified path!
new_classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, model_dir="path_to_my_local_dir")
accuracy_score = new_classifier.evaluate(x=x_test, y=y_test)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict_proba(new_samples)
print ('Predictions: {}'.format(str(y)))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.