Save checkpoint with TensorFlow
My CNN model uses three folders: train_data, val_data, and test_data.
When I train my model, I notice that the accuracy varies between epochs, and sometimes the last epoch is not the most accurate one. For example, the last epoch reaches 71% accuracy, but an earlier epoch reached a higher accuracy. I would like to save a checkpoint at the epoch with the higher accuracy, and then use that checkpoint to run predictions on test_data.
I train my model on train_data, evaluate it on val_data, and save a checkpoint of the model like this:
print("{} Saving checkpoint of model...".format(datetime.now()))
checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
save_path = saver.save(session, checkpoint_path)
Before starting the tf.Session(), I have this line:
saver = tf.train.Saver()
How can I save the checkpoint of the best epoch, the one with the highest accuracy, and then use that checkpoint on my test_data?
The tf.train.Saver() documentation describes the following:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
Note that if you pass global_step to the saver, you will produce checkpoint files that contain the global step number. I usually save a checkpoint every X minutes, then come back, look at the results, and pick a checkpoint at the appropriate step value. If you use TensorBoard you will find this intuitive, because all your graphs can also be displayed by global step.
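A direct way to get what the question asks for is to save only when the validation accuracy improves, instead of saving every epoch and picking afterwards. The bookkeeping is plain Python and independent of TensorFlow; the class name below is hypothetical, and the only TF call involved is the saver.save shown in the comment, a minimal sketch assuming validation accuracy is computed once per epoch:

```python
class BestCheckpointTracker:
    """Tracks the best validation accuracy seen so far and reports
    whether the current epoch should be checkpointed."""

    def __init__(self):
        self.best_accuracy = float('-inf')
        self.best_epoch = None

    def should_save(self, epoch, accuracy):
        # Save only when this epoch beats every earlier epoch.
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = epoch
            return True
        return False


# Inside the training loop this would be used roughly as:
#   if tracker.should_save(epoch, val_accuracy):
#       saver.save(session, os.path.join(checkpoint_dir, 'model_best.ckpt'))

tracker = BestCheckpointTracker()
decisions = [tracker.should_save(e, a)
             for e, a in enumerate([0.60, 0.74, 0.71, 0.78, 0.71])]
print(decisions)           # [True, True, False, True, False]
print(tracker.best_epoch)  # 3
```

Because the best checkpoint is always written to the same path, restoring it later for test_data needs no searching through epoch numbers.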
You can use a CheckpointSaverListener.
from __future__ import print_function
import glob
import os

import tensorflow as tf
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
from sacred import Experiment

ex = Experiment('test-07-05-2018')
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
checkpoint_path = "/tmp/checkpoints/"

class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def begin(self):
        print('Starting the session.')
        self.prev_accuracy = 0.
        self.acc = 0.

    def after_save(self, session, global_step_value):
        print('Only keep this checkpoint if it is better than the previous one')
        if self.acc < self.prev_accuracy:
            # Discard all files belonging to the checkpoint just written
            for f in glob.glob(tf.train.latest_checkpoint(checkpoint_path) + '*'):
                os.remove(f)
        else:
            self.prev_accuracy = self.acc

    def end(self, session, global_step_value):
        print('Done with the session.')

@ex.config
def my_config():
    pass

@ex.automain
def main():
    # Build the graph of vanilla multiclass logistic regression
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y_pred = tf.nn.softmax(tf.matmul(x, W) + b)
    loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), axis=1))
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    init = tf.global_variables_initializer()
    y_pred_cls = tf.argmax(y_pred, axis=1)
    y_true_cls = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(y_pred_cls, y_true_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    listener = ExampleCheckpointSaverListener()
    # CheckpointSaverHook requires a save frequency (save_steps or save_secs)
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_path, save_steps=500, listeners=[listener])
    with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]) as sess:
        sess.run(init)
        for epoch in range(25):
            avg_loss = 0.
            total_batch = int(mnist.train.num_examples / 100)
            # Loop over all batches
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                _, l, acc = sess.run([optimizer, loss, accuracy],
                                     feed_dict={x: batch_xs, y: batch_ys})
                avg_loss += l / total_batch
            # Let the listener see the latest accuracy before saving
            listener.acc = acc
            saver.save(sess, checkpoint_path)
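Whichever way the checkpoints are produced, the question's second half is how to use the best one on test_data. If checkpoints are saved per epoch with the 'model_epoch<N>.ckpt' naming from the question, a small helper (hypothetical name, assuming you also recorded each epoch's validation accuracy in a dict) can pick the path to restore:

```python
import os

def best_checkpoint_path(accuracy_by_epoch, checkpoint_dir):
    """Given a dict mapping epoch -> validation accuracy, return the
    checkpoint path of the best epoch, using the same
    'model_epoch<N>.ckpt' naming as in the question."""
    best_epoch = max(accuracy_by_epoch, key=accuracy_by_epoch.get)
    return os.path.join(checkpoint_dir, 'model_epoch%d.ckpt' % best_epoch)

path = best_checkpoint_path({0: 0.65, 1: 0.78, 2: 0.71}, '/tmp/checkpoints')
print(path)  # /tmp/checkpoints/model_epoch1.ckpt

# That path is then restored before predicting on test_data,
# roughly like this (sketch, assumes the same graph is rebuilt first):
#   saver = tf.train.Saver()
#   with tf.Session() as session:
#       saver.restore(session, path)
#       session.run(predictions, feed_dict={x: test_batch})
```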