簡體   English   中英

如何保存和恢復在TensorFlow python中訓練的DNNClassifier; 虹膜的例子

[英]How to save&restore DNNClassifier trained in TensorFlow python; iris example

我是TensorFlow的新手,幾天前才開始學習。 我已完成本教程( https://www.tensorflow.org/versions/r0.9/tutorials/tflearn/index.html#tf-contrib-learn-quickstart )並將完全相同的想法應用於我自己的數據集。 (出來很不錯!)

現在,我想保存並恢復經過培訓的DNNClassifier以供進一步使用。 如果有人知道如何操作,請使用上面鏈接中的iris示例代碼告訴我。 感謝您的幫助!

找到了解決方案嗎? 如果沒有,您可以在創建DNNClassifier時在構造函數上指定model_dir參數,這將創建此目錄中的所有檢查點和文件(保存步驟)。 如果要執行還原步驟,只需創建另一個傳遞相同model_dir參數(還原階段)的DNNClassifier,這將從第一次創建的文件中還原模型。

希望這對你有所幫助。

以下是我的代碼......

import tensorflow as tf
import numpy as np

if __name__ == '__main__':
# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING, target_dtype=np.int)
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST, target_dtype=np.int)

x_train, x_test, y_train, y_test = training_set.data, test_set.data, training_set.target, test_set.target

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, model_dir="path_to_my_local_dir")

# print classifier.model_dir

# Fit model.
print "start fitting model..."
classifier.fit(x=x_train, y=y_train, steps=200)
print "finished fitting model!!!"

# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=x_test, y=y_test)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

#Classify two new flower samples.
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict_proba(new_samples)
print ('Predictions: {}'.format(str(y)))

#---------------------------------------------------------------------------------
#model_dir below has to be the same as the previously specified path!
new_classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, model_dir="path_to_my_local_dir")
accuracy_score = new_classifier.evaluate(x=x_test, y=y_test)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict_proba(new_samples)
print ('Predictions: {}'.format(str(y)))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM