How to use two models in the TensorFlow Object Detection API

In the TensorFlow Object Detection API, we are using the ssd_mobilenet_v1_coco_2017_11_17 model to detect 90 general objects, and I want to keep using this model for detection. In addition, I have trained a faster_rcnn_inception_v2_coco_2018_01_28 model to detect a custom object. I want to use both in the same code, so that I can detect those 90 objects as well as my newly trained custom object. How can I achieve this in a single script?

I have achieved this with the following code in detect_object.py:

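import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.gfile, tf.Session)
import cv2

# 'utils' is the object_detection/utils package from the TF models repository;
# this import assumes the script is run from models/research/object_detection.
from utils import label_map_util
from utils import visualization_utils as vis_util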

# ------------------ Knife Model Initialization ------------------------------ #
knife_label_map = label_map_util.load_labelmap('training/labelmap.pbtxt')
knife_categories = label_map_util.convert_label_map_to_categories(
    knife_label_map, max_num_classes=1, use_display_name=True)
knife_category_index = label_map_util.create_category_index(knife_categories)

knife_detection_graph = tf.Graph()

with knife_detection_graph.as_default():
    knife_od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('inference_graph_3/frozen_inference_graph.pb', 'rb') as fid:
        knife_serialized_graph = fid.read()
        knife_od_graph_def.ParseFromString(knife_serialized_graph)
        tf.import_graph_def(knife_od_graph_def, name='')

    knife_session = tf.Session(graph=knife_detection_graph)

knife_image_tensor = knife_detection_graph.get_tensor_by_name('image_tensor:0')
knife_detection_boxes = knife_detection_graph.get_tensor_by_name(
    'detection_boxes:0')
knife_detection_scores = knife_detection_graph.get_tensor_by_name(
    'detection_scores:0')
knife_detection_classes = knife_detection_graph.get_tensor_by_name(
    'detection_classes:0')
knife_num_detections = knife_detection_graph.get_tensor_by_name(
    'num_detections:0')
# ---------------------------------------------------------------------------- #

# ------------------ General Model Initialization ---------------------------- #
general_label_map = label_map_util.load_labelmap('data/mscoco_label_map.pbtxt')
general_categories = label_map_util.convert_label_map_to_categories(
    general_label_map, max_num_classes=90, use_display_name=True)
general_category_index = label_map_util.create_category_index(
    general_categories)

general_detection_graph = tf.Graph()

with general_detection_graph.as_default():
    general_od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb', 'rb') as fid:
        general_serialized_graph = fid.read()
        general_od_graph_def.ParseFromString(general_serialized_graph)
        tf.import_graph_def(general_od_graph_def, name='')

    general_session = tf.Session(graph=general_detection_graph)

general_image_tensor = general_detection_graph.get_tensor_by_name(
    'image_tensor:0')
general_detection_boxes = general_detection_graph.get_tensor_by_name(
    'detection_boxes:0')
general_detection_scores = general_detection_graph.get_tensor_by_name(
    'detection_scores:0')
general_detection_classes = general_detection_graph.get_tensor_by_name(
    'detection_classes:0')
general_num_detections = general_detection_graph.get_tensor_by_name(
    'num_detections:0')
# ---------------------------------------------------------------------------- #


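def knife(image_path):
    """Return the confidence (0-100) of the first 'knife' detection, or 0.0."""
    confidence = 0.0
    try:
        image = cv2.imread(image_path)
        if image is None:
            raise IOError("Could not read image: " + image_path)
        # OpenCV loads images as BGR; the Object Detection API models expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_expanded = np.expand_dims(image, axis=0)
        (boxes, scores, classes, num) = knife_session.run(
            [knife_detection_boxes, knife_detection_scores,
                knife_detection_classes, knife_num_detections],
            feed_dict={knife_image_tensor: image_expanded})

        classes = np.squeeze(classes).astype(np.int32)
        scores = np.squeeze(scores)
        boxes = np.squeeze(boxes)

        # Only the first num[0] entries are real detections; they come sorted
        # by descending score, so the first match is the most confident one.
        for c in range(int(num[0])):
            class_name = knife_category_index[classes[c]]['name']
            if class_name == 'knife' and scores[c] > .80:
                confidence = scores[c] * 100
                break
    except Exception as e:
        print("Error occurred in knife detection:", e)
        confidence = 0.0   # Some error has occurred
    return confidence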


def general(image_path):
    """Return (names, scores) for COCO detections above 30% confidence."""
    object_name = []
    object_score = []
    try:
        image = cv2.imread(image_path)
        if image is None:
            raise IOError("Could not read image: " + image_path)
        # OpenCV loads images as BGR; the Object Detection API models expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_expanded = np.expand_dims(image, axis=0)
        (boxes, scores, classes, num) = general_session.run(
            [general_detection_boxes, general_detection_scores,
                general_detection_classes, general_num_detections],
            feed_dict={general_image_tensor: image_expanded})

        classes = np.squeeze(classes).astype(np.int32)
        scores = np.squeeze(scores)
        boxes = np.squeeze(boxes)

        # Only the first num[0] entries are real detections.
        for c in range(int(num[0])):
            if scores[c] > .30:   # If confidence level is good enough
                class_name = general_category_index[classes[c]]['name']
                object_name.append(class_name)
                object_score.append('{:.2f}'.format(scores[c] * 100))
    except Exception as e:
        print("Error occurred in general detection:", e)
        object_name = ['']
        object_score = ['']

    return object_name, object_score


if __name__ == '__main__':
    print(' in main')
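Both `knife()` and `general()` end with post-processing that does not depend on TensorFlow: `create_category_index` produces a dict mapping each class ID to an `{'id': ..., 'name': ...}` entry, and the detections are scanned against a score threshold. A pure-Python sketch of those two rules, limited to the first `num_detections` entries. The function names `target_confidence` and `filter_detections`, the placeholder name for unknown IDs, and the sample values are hypothetical, and plain lists stand in for the squeezed NumPy arrays:

```python
def target_confidence(classes, scores, num_detections, category_index,
                      target='knife', threshold=0.80):
    """Percentage score of the first `target` detection above `threshold`.

    Detections from the Object Detection API come sorted by descending
    score, so the first match is also the most confident one. Returns
    0.0 when nothing matches, like knife().
    """
    for c in range(int(num_detections)):
        category = category_index.get(int(classes[c]))
        if category is None:
            continue  # class ID not present in the label map
        if category['name'] == target and scores[c] > threshold:
            return scores[c] * 100
    return 0.0


def filter_detections(classes, scores, num_detections, category_index,
                      threshold=0.30):
    """(name, 'NN.NN') pairs for every detection above `threshold`.

    Same selection as general(), but kept as pairs instead of two
    parallel lists.
    """
    results = []
    for c in range(int(num_detections)):
        if scores[c] <= threshold:
            continue
        category = category_index.get(int(classes[c]))
        # Hypothetical placeholder name for IDs missing from the label map.
        name = category['name'] if category else 'id {}'.format(int(classes[c]))
        results.append((name, '{:.2f}'.format(scores[c] * 100)))
    return results


# Hypothetical category indexes, shaped like create_category_index output.
knife_index = {1: {'id': 1, 'name': 'knife'}}
coco_subset = {1: {'id': 1, 'name': 'person'}, 3: {'id': 3, 'name': 'car'}}

print(target_confidence([1, 1], [0.9, 0.5], 2, knife_index))
print(filter_detections([1, 3, 1], [0.92, 0.41, 0.12], 3, coco_subset))
# prints [('person', '92.00'), ('car', '41.00')]
```

Looking IDs up with `dict.get` means a class ID missing from the label map is skipped (or given a placeholder name) instead of raising a `KeyError` inside the `try` block.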

I can then call:

import detect_object

# Confidence (0-100) that a knife is present, from the custom-trained model
knife_confidence = detect_object.knife("image.jpg")

# Names and scores of the 90 COCO objects, from the TF API model
names, scores = detect_object.general("image.jpg")
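`knife()` reports only whether a knife is present, although the squeezed `boxes` array also says where each detection is, with `boxes[c]` pairing with `scores[c]` and `classes[c]`. In graphs exported by the Object Detection API each box is normalized to [0, 1] and ordered [ymin, xmin, ymax, xmax], so it has to be scaled by the image size before drawing or cropping. A minimal sketch of that conversion; `to_pixel_box` is a hypothetical helper, not part of the API:

```python
def to_pixel_box(box, image_width, image_height):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to integer pixels.

    Returns (left, top, right, bottom); (left, top) and (right, bottom)
    can be used as the two corner points for cv2.rectangle.
    """
    ymin, xmin, ymax, xmax = box
    left = int(round(xmin * image_width))
    top = int(round(ymin * image_height))
    right = int(round(xmax * image_width))
    bottom = int(round(ymax * image_height))
    return left, top, right, bottom


# Hypothetical box covering the central quarter of a 640x480 image.
print(to_pixel_box([0.25, 0.25, 0.75, 0.75], 640, 480))  # prints (160, 120, 480, 360)
```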

I know the TF API's COCO model already has a knife class, but it is not very accurate, so I retrained a model for knives only. In the end I have two models:

1. The first model detects only knives.
2. The second model detects general objects as usual.

You can't combine the two models into one. Instead, keep two sections of code, each loading one model and identifying whatever it can see in the image. The other option is to retrain a single model that can identify all the objects you are interested in.
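With the first option, the two models stay separate and only their outputs are combined afterwards. A minimal sketch of that last step, assuming each model's results have already been reduced to (class name, score) pairs with numeric scores between 0 and 1; `merge_detections` and the model labels are hypothetical. Each entry is tagged with the model that produced it, and the combined list is sorted by score:

```python
def merge_detections(results_by_model):
    """Combine {model_name: [(class_name, score), ...]} into one list.

    Each entry becomes (score, class_name, model_name), sorted so the most
    confident detection comes first regardless of which model produced it.
    """
    merged = []
    for model_name, detections in results_by_model.items():
        for class_name, score in detections:
            merged.append((score, class_name, model_name))
    merged.sort(key=lambda item: item[0], reverse=True)
    return merged


combined = merge_detections({
    'general': [('person', 0.92), ('knife', 0.35)],  # COCO model output
    'knife': [('knife', 0.88)],                      # custom model output
})
for score, name, source in combined:
    print('{:<8} {:.2f} from {}'.format(name, score, source))
```

If both models report a knife, a caller could keep only the custom model's entry, since the asker found it more accurate; that policy is not part of this sketch.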

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.