简体   繁体   English

如何加载由 Google 命名为 inception 的预训练张量流模型?

[英]How to load pre-trained tensorflow model named inception by Google?

I have downloaded a tensorflow checkpoint model named inception_resnet_v2_2016_08_30.ckpt .我下载了一个名为inception_resnet_v2_2016_08_30.ckpt的 tensorflow 检查点模型。

Do I need to create a graph (with all the variables) that were used when this checkpoint was created?我是否需要创建一个在创建此检查点时使用的图形(包含所有变量)?

How do I make use of this model?我如何使用这个模型?

First of you have get the network architecture in memory. 首先,您要在内存中获得网络体系结构。 You can get the network architecture from here 您可以从这里获取网络架构

Once you have this program with you, use the following approach to use the model: 拥有此程序后,请使用以下方法来使用模型:

from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope

height = 299
width = 299
channels = 3

X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
with slim.arg_scope(inception_resnet_v2_arg_scope()):
     logits, end_points = inception_resnet_v2(X, num_classes=1001,is_training=False)

With this you have all the network in memory, Now you can initialize the network with checkpoint file(ckpt) by using tf.train.saver: 这样,您就可以将所有网络存储在内存中,现在您可以使用tf.train.saver使用检查点文件(ckpt)初始化网络:

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "/home/pramod/Downloads/inception_resnet_v2_2016_08_30.ckpt")

If you want to do bottle feature extraction, its simple like lets say you want to get features from last layer, then simply you have to declare predictions = end_points["Logits"] If you want to get it for other intermediate layer, you can get those names from the above program inception_resnet_v2.py 如果要进行瓶特征提取,其简单方法例如可以说要从最后一层获取特征,那么只需声明predictions = end_points["Logits"]如果要为其他中间层获取它,则可以从上面的程序inception_resnet_v2.py获得这些名称

After that you can call: output = sess.run(predictions, feed_dict={X:batch_images}) 之后,您可以调用: output = sess.run(predictions, feed_dict={X:batch_images})

Do I need to create a graph (with all the variables) that were used when this checkpoint was created? 我需要创建一个创建此检查点时使用过的图形(包含所有变量)吗?

No, you don't. 不,你没有。

As for how to use checkpoint file (cpkt file) 至于如何使用检查点文件(cpkt文件)

1.This article ( TensorFlow-Slim image classification library ) tells you how to train your model from scratch 1.本文( TensorFlow-Slim图像分类库 )告诉您如何从头开始训练模型

2.The following is an example code from google blog 2.以下是来自Google博客的示例代码

import numpy as np
import os
import tensorflow as tf
import urllib2

from datasets import imagenet
from nets import inception
from preprocessing import inception_preprocessing

slim = tf.contrib.slim

batch_size = 3
image_size = inception.inception_v3.default_image_size

checkpoints_dir = '/root/code/model'
checkpoints_filename = 'inception_resnet_v2_2016_08_30.ckpt'
model_name = 'InceptionResnetV2'
sess = tf.InteractiveSession()
graph = tf.Graph()
graph.as_default()

def classify_from_url(url):
    image_string = urllib2.urlopen(url).read()
    image = tf.image.decode_jpeg(image_string, channels=3)
    processed_image = inception_preprocessing.preprocess_image(image,     image_size, image_size, is_training=False)
processed_images  = tf.expand_dims(processed_image, 0)

# Create the model, use the default arg scope to configure the batch norm parameters.
with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
    logits, _ = inception.inception_resnet_v2(processed_images, num_classes=1001, is_training=False)
probabilities = tf.nn.softmax(logits)

init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(checkpoints_dir, checkpoints_filename),
    slim.get_model_variables(model_name))

init_fn(sess)
np_image, probabilities = sess.run([image, probabilities])
probabilities = probabilities[0, 0:]
sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])]

plt.figure()
plt.imshow(np_image.astype(np.uint8))
plt.axis('off')
plt.show()

names = imagenet.create_readable_names_for_imagenet_labels()
for i in range(5):
    index = sorted_inds[i]
    print('Probability %0.2f%% => [%s]' % (probabilities[index], names[index]))

Another way of loading a pre-trained Imagenet model is加载预训练 Imagenet 模型的另一种方法是

ResNet50 ResNet50

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
model = ResNet50()
model.summary()

InceptionV3创始V3

iport tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3
model = InceptionV3()
model.summary()

You can check a detailed explanation related to this here您可以在此处查看与此相关的详细说明

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM