How to properly resume the training of a network from a TensorFlow checkpoint file?

I have been struggling for a day to restore a model, without any success. My code consists of a class TF_MLPRegressor(), where I define the network architecture within the constructor. Then I invoke the fit() function to do the training. This is how I save a simple perceptron model with 1 hidden layer within the fit() function:

            starting_epoch = 0
            # Launch the graph
            tf.set_random_seed(self.random_state)   # fix the random seed before creating the Session in order to take effect!
            if hasattr(self, 'sess'):
                self.sess.close()
                del self.sess   # delete Session to release memory
                gc.collect()
            self.sess = tf.Session(config=self.config) # save the session to predict from new data
            # Create a saver object which will save all the variables
            saver = tf.train.Saver(max_to_keep=2)  # max_to_keep=2 means to not keep more than 2 checkpoint files
            self.sess.run(tf.global_variables_initializer())

# ... (each 100 epochs)

            saver.save(self.sess, self.checkpoint_dir+"/resume", global_step=epoch)

Then I create a new TF_MLPRegressor() instance with exactly the same input parameter values and invoke the fit() function to restore the model like this:

    self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
    ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
    starting_epoch = int(ckpt.split('-')[-1])
    metagraph = ".".join([ckpt, 'meta'])
    saver = tf.train.import_meta_graph(metagraph)
    self.sess.run(tf.global_variables_initializer())    # Initialize variables
    lhl = tf.trainable_variables()[2]
    lhlA = lhl.eval(session=self.sess)
    saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
    lhlB = lhl.eval(session=self.sess)
    print(lhlA == lhlB)

lhlA and lhlB are the last hidden layer weights before and after restoring, and according to my code they match completely, meaning the saved model is not loaded into the session. What am I doing wrong?
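
One detail in the restore snippet above may be worth guarding: tf.train.latest_checkpoint returns None when no checkpoint is found, and splitting on '-' also misbehaves if the directory name itself contains a hyphen. A minimal sketch of a more defensive parser follows; step_from_checkpoint is a hypothetical helper name, not part of TensorFlow.

```python
import re


def step_from_checkpoint(ckpt_path, default=0):
    """Return the global step encoded at the end of a checkpoint prefix
    such as '/tmp/ckpts/resume-300', or `default` if there is none."""
    if not ckpt_path:  # latest_checkpoint() gives None when nothing was saved yet
        return default
    # Only accept a trailing '-<digits>' suffix, as written by saver.save(..., global_step=epoch).
    match = re.search(r'-(\d+)$', ckpt_path)
    return int(match.group(1)) if match else default


print(step_from_checkpoint('/tmp/ckpts/resume-300'))  # prints 300
print(step_from_checkpoint(None))                     # prints 0
```

With such a helper, starting_epoch could be set as step_from_checkpoint(ckpt), so that a first run with an empty checkpoint directory simply starts from epoch 0.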

I found a workaround! Strangely, the metagraph either does not contain all the variables that I defined or assigns new names to them. For example, in the constructor I define the tensors that will carry the input feature vectors and the experimental values:

self.x = tf.placeholder("float", [None, feat_num], name='x')
self.y = tf.placeholder("float", [None], name='y')

However, when I call tf.reset_default_graph() and load the metagraph, I get the following list of variables:

[
<tf.Variable 'Variable:0' shape=(300, 300) dtype=float32_ref>, 
<tf.Variable 'Variable_1:0' shape=(300,) dtype=float32_ref>, 
<tf.Variable 'Variable_2:0' shape=(300, 1) dtype=float32_ref>, 
<tf.Variable 'Variable_3:0' shape=(1,) dtype=float32_ref>
]

For the record, each input feature vector has 300 features. Anyway, when I later try to start training using:

_, c, p = self.sess.run([self.optimizer, self.cost, self.pred], 
feed_dict={self.x: batch_x, self.y: batch_y, self.isTrain: True})

I get an error like:

"TypeError: Cannot interpret feed_dict key as Tensor: Tensor 'x' is not an element of this graph."

So, since I define the network architecture within the constructor every time I create an instance of class TF_MLPRegressor(), I decided not to load the metagraph, and it worked! I don't know why TF doesn't save all variables into the metagraph; maybe it is because I define the network architecture explicitly (I don't use wrappers or default layers), as in the example below:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py
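
A possible explanation, offered as an assumption rather than a diagnosis of the original code: placeholders are not variables, so they never appear in tf.trainable_variables(), and names such as 'Variable:0' are the defaults given to tf.Variable objects created without a name argument. The feed_dict error is typically raised when self.x belongs to a graph other than the one the session runs, for instance the graph discarded by tf.reset_default_graph(). If the metagraph is loaded anyway, tensors can be looked up by name in the imported graph. The sketch below assumes the v1 compatibility API (tensorflow.compat.v1) and a tiny hypothetical model whose tensor names 'x', 'w' and 'pred' are chosen only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x with the v1 compat API

tf.disable_v2_behavior()

checkpoint_dir = tempfile.mkdtemp()
batch = np.ones((2, 4), dtype=np.float32)

# Build and save a tiny hypothetical model with explicitly named tensors.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4], name='x')
    w = tf.Variable(tf.random_normal([4, 1]), name='w')
    pred = tf.matmul(x, w, name='pred')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        expected = sess.run(pred, feed_dict={x: batch})
        # save() returns the checkpoint prefix, e.g. '<dir>/resume-100'
        ckpt = saver.save(sess, os.path.join(checkpoint_dir, 'resume'), global_step=100)

# Load the metagraph into a fresh graph and look tensors up by name.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(ckpt + '.meta')
    x = graph.get_tensor_by_name('x:0')        # the placeholder is in the graph, just not a variable
    pred = graph.get_tensor_by_name('pred:0')
    variable_names = [v.name for v in tf.trainable_variables()]
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        restored = sess.run(pred, feed_dict={x: batch})

print(variable_names)                  # only the variable 'w:0' is listed
print(np.allclose(restored, expected))
```

Rebuilding the graph in the constructor, as the workaround does, avoids the name lookups altogether because self.x and self.y already point at tensors of the current graph.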

To sum up, I save my models as described in my first message, but to restore them I use this:

saver = tf.train.Saver(max_to_keep=2)
self.sess = tf.Session(config=self.config)  # create a new session to load saved variables
self.sess.run(tf.global_variables_initializer())
ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
saver.restore(sess=self.sess, save_path=ckpt)   # Restore model weights from previously saved model
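
A self-contained sketch of this rebuild-then-restore pattern may help to check it in isolation. It assumes TensorFlow's v1 compatibility API (tensorflow.compat.v1); build_graph, the layer sizes and the temporary directory are hypothetical stand-ins for the constructor of TF_MLPRegressor(), not the original code.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x with the v1 compat API

tf.disable_v2_behavior()


def build_graph(feat_num=4, hidden=3):
    """Hypothetical stand-in for the constructor: defines the same graph on every call."""
    x = tf.placeholder(tf.float32, [None, feat_num], name='x')
    w = tf.Variable(tf.random_normal([feat_num, hidden]), name='w')
    b = tf.Variable(tf.zeros([hidden]), name='b')
    pred = tf.add(tf.matmul(x, w), b, name='pred')
    return x, w, pred


checkpoint_dir = tempfile.mkdtemp()

# First run: build the graph, initialize the variables and save a checkpoint.
with tf.Graph().as_default():
    _, w, _ = build_graph()
    saver = tf.train.Saver(max_to_keep=2)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_w = sess.run(w)
        saver.save(sess, os.path.join(checkpoint_dir, 'resume'), global_step=100)

# Second run: rebuild the same graph in code (no metagraph), then restore into it.
with tf.Graph().as_default():
    _, w, _ = build_graph()
    saver = tf.train.Saver(max_to_keep=2)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        fresh_w = sess.run(w)              # random values before restoring
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        saver.restore(sess, ckpt)
        restored_w = sess.run(w)           # values loaded from the checkpoint

weights_changed = not np.array_equal(fresh_w, restored_w)
weights_match_saved = np.array_equal(saved_w, restored_w)
print(weights_changed, weights_match_saved)
```

Because the Saver restores every variable here, the global_variables_initializer() call before restore() is not strictly needed; it is kept only to make the change in the weights visible.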

