[英]How to restore tensorflow inceptions checkpoint file (ckpt)?
I have inception_resnet_v2_2016_08_30.ckpt
file which is a pre-trained inception model.我有inception_resnet_v2_2016_08_30.ckpt
文件,它是一个预训练的初始模型。 I want to restore this model using我想使用恢复这个模型
saver.restore(sess, ckpt_filename)
But for that, I will be required to write the set of variables that were used while training this model.但为此,我将需要编写训练此模型时使用的变量集。 Where can I find those (a script, or detailed description)?我在哪里可以找到那些(脚本或详细描述)?
First of you have get the network architecture in memory.首先,您已获得内存中的网络架构。 You can get the network architecture from here你可以从这里获得网络架构
Once you have this program with you, use the following approach to use the model:有了这个程序后,请使用以下方法来使用模型:
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
height = 299
width = 299
channels = 3
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_resnet_v2_arg_scope()):
logits, end_points = inception_resnet_v2(X, num_classes=1001,is_training=False)
With this you have all the network in memory, Now you can initialize the network with checkpoint file(ckpt) by using tf.train.saver:有了这个,您就拥有了内存中的所有网络,现在您可以使用 tf.train.saver 使用检查点文件(ckpt)初始化网络:
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "/home/pramod/Downloads/inception_resnet_v2_2016_08_30.ckpt")
If you want to do bottle extractions, its simple like lets say you want to get features from last layer, then simply you have to declare predictions = end_points["Logits"]
If you want to get it for other intermediate layer, you can get those names from the above program inception_resnet_v2.py如果你想做瓶子提取,它很简单,比如你想从最后一层获取特征,那么你只需要声明predictions = end_points["Logits"]
如果你想为其他中间层获取它,你可以得到上面程序中的那些名字 inception_resnet_v2.py
After that you can call: output = sess.run(predictions, feed_dict={X:batch_images})
之后你可以调用: output = sess.run(predictions, feed_dict={X:batch_images})
I believe the MetaGraph
mechanism is what you need.我相信MetaGraph
机制正是您所需要的。
EDIT: additionally, take a look at tf.train.NewCheckpointReader
-- it has a get_variable_to_shape_map()
method.编辑:另外,看看tf.train.NewCheckpointReader
- 它有一个get_variable_to_shape_map()
方法。 See unit test .请参阅单元测试。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.