簡體   English   中英

使用 Tensorflow 保存檢查點

[英]save checkpoint with Tensorflow

我的 CNN 模型有 3 個文件夾,分別是train_data, val_data, test_data.

當我訓練我的模型時,我發現准確度可能會有所不同,有時最后一個時期沒有顯示出最好的准確度。 例如,最后一個時期的准確率為 71%,但我發現在較早的時期准確度更高。 我想保存那個具有更高准確性的時代的檢查點,然后使用該檢查點在test_data上預測我的模型

我在train_data上訓練我的模型並在val_data上預測並保存模型的檢查點,如下所示:

    print("{} Saving checkpoint of model...". format(datetime.now()))
    checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
    save_path = saver.save(session, checkpoint_path)

在開始tf.Session()之前,我有這一行:

saver = tf.train.Saver()

我想知道如何保存具有更高准確性的最佳紀元,然后將此檢查點用於我的test_data

tf.train.Saver()文檔描述了以下內容:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

請注意,如果您將global_step傳遞給保護程序,您將生成包含全局步驟編號的檢查點文件。 我通常每 X 分鍾保存一次檢查點,然后回來查看結果並在適當的步長值處選擇一個檢查點。 如果您使用 tensorboard,您會發現這很直觀,因為您的所有圖形也可以按全局步驟顯示。

https://www.tensorflow.org/api_docs/python/tf/train/Saver

您可以使用CheckpointSaverListener

from __future__ import print_function
import tensorflow as tf
import os
from sacred import Experiment

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

ex = Experiment('test-07-05-2018')    

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
checkpoint_path = "/tmp/checkpoints/"

class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
       print('Starting the session.')
       self.prev_accuracy = 0
       self.acc = 0

   def after_save(self, session, global_step_value):
       print('Only keep this checkpoint if it is better than the previous one')
       self.acc = acc 
       if self.acc <  self.prev_accuracy :
            os.remove(tf.train.latest_checkpoint())
       else:
            self.prev_accuracy = self.acc

   def end(self, session, global_step_value):
       print('Done with the session.')

@ex.config
def my_config():
pass

@ex.automain
def main():
      #build the graph of vanilla multiclass logistic regression
      x = tf.placeholder(tf.float32, [None, 784])
      y = tf.placeholder(tf.float32, [None, 10]) 
      W = tf.Variable(tf.zeros([784, 10]))
      b = tf.Variable(tf.zeros([10]))
      y_pred = tf.nn.softmax(tf.matmul(x, W) + b) #
      loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred), reduction_indices=1))
      optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
      init = tf.global_variables_initializer()
      y_pred_cls = tf.argmax(y_pred, dimension=1)
      y_true_cls = tf.argmax(y, dimension=1)
      correct_prediction = tf.equal(y_pred_cls, y_true_cls)
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
      saver = tf.train.Saver()
      listener = ExampleCheckpointSaverListener()
      saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir, listeners=[listener])
      with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]) as sess:
          sess.run(init)
          for epoch in range(25):
              avg_loss = 0.
              total_batch = int(mnist.train.num_examples/100)
              # Loop over all batches
              for i in range(total_batch):
                  batch_xs, batch_ys = mnist.train.next_batch(100)
                  _, l, acc = sess.run([optimizer, loss, accuracy], feed_dict={x: batch_xs, y: batch_ys})
                  avg_loss += l / total_batch
                  saver.save(sess, checkpoint_path)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM