Inception Resnet V2 model freezing

I used the InceptionResNet v2 model to train an image classifier with transfer learning. The model works well; the problem is freezing it. Currently, I have:

  • model.ckpt.meta
  • model.ckpt.index
  • model.ckpt

I used this tutorial to freeze the model, setting output_node_names to InceptionResnetV2/Logits/Predictions, and the frozen model was generated correctly. I now have a new file called model.pb.

The code used to freeze the model:

import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))


def freeze_graph(model_folder, output_node_names):
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We specify the full file name of our frozen graph
    # (os.path keeps this working with Windows path separators too)
    absolute_model_folder = os.path.dirname(input_checkpoint)
    output_graph = os.path.join(absolute_model_folder, "frozen_model.pb")

    # Before exporting our graph, we need to specify our output node
    # This is how TF decides which part of the graph to keep and which part it can discard
    # NOTE: this variable is plural, because you can have multiple output nodes
    # output_node_names = "Accuracy/predictions"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names are used to select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
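One detail of the function above: `output_node_names.split(",")` keeps any whitespace around each name, so a string such as `"a, b"` yields `" b"`, which would not match a node in the graph. A more forgiving parser could look like the sketch below; the helper name `parse_output_node_names` is hypothetical and not part of TensorFlow or the tutorial.

```python
def parse_output_node_names(raw):
    """Split a comma-separated string into node names, ignoring stray spaces.

    Hypothetical helper for illustration; empty entries (for example from a
    trailing comma) are dropped.
    """
    return [name.strip() for name in raw.split(",") if name.strip()]


print(parse_output_node_names("InceptionResnetV2/Logits/Predictions"))
print(parse_output_node_names("node_a, node_b ,"))
```

Its result could be passed to `convert_variables_to_constants` in place of the plain `split(",")` call.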

The problem appears when I try to feed the model an input.

First, I load the model graph using:

def load_graph(frozen_graph_filename):
    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we can use again a convenient built-in function to import a graph_def into the
    # current default Graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="prefix",
            op_dict=None,
            producer_op_list=None
        )
    return graph

Then, when I explore the graph ops, I do not find the input placeholder:

for op in graph.get_operations():
    print(op.name)

The first ops listed are:

prefix/batch/fifo_queue
prefix/batch/n
prefix/batch
prefix/InceptionResnetV2/Conv2d_1a_3x3/weights
prefix/InceptionResnetV2/Conv2d_1a_3x3/weights/read
prefix/InceptionResnetV2/Conv2d_1a_3x3/convolution
prefix/InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta
prefix/InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/beta/read
prefix/InceptionResnetV2/Conv2d_1a_3x3/BatchNorm/moments/Mean/reduction_indices
...
prefix/InceptionResnetV2/Logits/Predictions
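When the op list is long, filtering by op type can narrow down which ops might accept input. The following is a minimal sketch, assuming each operation exposes `.name` and `.type` the way `tf.Operation` does. The helper name `find_input_candidates` and the set of op types are illustrative choices, and the op types given to the stand-in objects are assumed rather than read from the real graph, so that the snippet runs without TensorFlow.

```python
from types import SimpleNamespace

# Op types that often mark the entry point of an input pipeline.
# This set is an assumption for illustration, not an exhaustive list.
INPUT_OP_TYPES = {
    "Placeholder",
    "PlaceholderWithDefault",
    "QueueDequeueManyV2",
    "QueueDequeueUpToV2",
}


def find_input_candidates(operations, op_types=INPUT_OP_TYPES):
    """Return the names of operations whose type suggests a model input."""
    return [op.name for op in operations if op.type in op_types]


# Stand-in objects mimicking a few ops from the listing above
# (their op types are assumed, not taken from the actual graph).
ops = [
    SimpleNamespace(name="prefix/batch/fifo_queue", type="FIFOQueueV2"),
    SimpleNamespace(name="prefix/batch/n", type="Const"),
    SimpleNamespace(name="prefix/batch", type="QueueDequeueManyV2"),
    SimpleNamespace(name="prefix/InceptionResnetV2/Conv2d_1a_3x3/weights", type="Const"),
]
print(find_input_candidates(ops))
```

With a real graph, the same helper could be called as `find_input_candidates(graph.get_operations())`.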

I get an error when I feed an image using:

    img_path = 'img.jpg'

    img_data = imread(img_path)
    img_data = imresize(img_data, (299, 299, 3))
    img_data = img_data.astype(np.float32)
    img_data = np.expand_dims(img_data, 0)

    # print('Starting Session, setting the GPU memory usage to %f' % args.gpu_memory)
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory)
    # sess_config = tf.ConfigProto(gpu_options=gpu_options)
    persistent_sess = tf.Session(graph=graph)  # , config=sess_config)

    input_node = graph.get_tensor_by_name('prefix/batch/fifo_queue:0')
    output_node = graph.get_tensor_by_name('prefix/InceptionResnetV2/Logits/Predictions:0')

    predictions = persistent_sess.run(output_node, feed_dict={input_node: [img_data]})
    print(predictions)
    label_predicted = np.argmax(predictions[0])
    print(label_predicted)

Error:

 File /ImageClassification_TransferLearning System/ModelTraining/model/model_frezzing.py", line 96, in <module>
    predictions = persistent_sess.run(output_node, feed_dict={input_node: [img_data]})
  File "\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 895, in run
    run_metadata_ptr)
  File "\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1078, in _run
    subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
  File "\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py", line 122, in as_numpy_dtype
    return _TF_TO_NP[self._type_enum]
KeyError: 20
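The traceback shows the failure happening while the fed tensor's dtype enum is looked up in a table of NumPy-convertible types. In TensorFlow's `DataType` enum, the value 20 is `DT_RESOURCE`, which suggests that `prefix/batch/fifo_queue:0` is a handle to the queue rather than a tensor holding image data. The sketch below illustrates that reading with a small, hand-copied subset of the enum; the table and the helper `explain_feed_dtype` are hypothetical and not part of TensorFlow.

```python
# Hand-copied subset of TensorFlow's DataType enum, for illustration only;
# the real enum has many more entries.
DTYPE_ENUM_NAMES = {
    1: "DT_FLOAT",
    3: "DT_INT32",
    4: "DT_UINT8",
    7: "DT_STRING",
    20: "DT_RESOURCE",
}

# Values from the subset above that have no NumPy equivalent.
NON_FEEDABLE = {20}


def explain_feed_dtype(type_enum):
    """Describe whether a tensor with this dtype enum can take a NumPy feed."""
    if type_enum not in DTYPE_ENUM_NAMES:
        return "enum %d: not in this illustrative subset" % type_enum
    name = DTYPE_ENUM_NAMES[type_enum]
    if type_enum in NON_FEEDABLE:
        return "%s: a handle, not data; feed a different tensor" % name
    return "%s: can be fed from a NumPy array" % name


print(explain_feed_dtype(20))
print(explain_feed_dtype(1))
```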

I found the problem! I had to feed the model through the input tensor called prefix/batch:0:

 input_node = graph.get_tensor_by_name('prefix/batch:0')
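The name passed to `get_tensor_by_name` follows the pattern `<import prefix>/<op name>:<output index>`, which is why the op printed as `prefix/batch` is addressed as `prefix/batch:0`. A small sketch of that composition follows; the helper `tensor_name` is hypothetical and only builds the string.

```python
def tensor_name(op_name, output_index=0, prefix=""):
    """Build a tensor name such as 'prefix/batch:0' from its parts.

    `prefix` corresponds to the `name` argument given to tf.import_graph_def;
    `output_index` selects which output of the op is meant.
    """
    scoped = "%s/%s" % (prefix, op_name) if prefix else op_name
    return "%s:%d" % (scoped, output_index)


print(tensor_name("batch", prefix="prefix"))
print(tensor_name("InceptionResnetV2/Logits/Predictions", prefix="prefix"))
```

The two printed names match the input and output tensors used in the prediction code.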
