
Export TFLite ssd-mobilenet without NMS and with bounding box decoding logic using the Tensorflow Object Detection API

I converted an ssd_mobilenet_v1 model to TFLite format using the Tensorflow Object Detection API's export_tflite_ssd_graph.py. Since I don't want the post-processing (NMS) in my final graph, I set the parameter --add_postprocessing_op to false. The exported model has two outputs, raw_outputs/box_encodings and raw_outputs/class_predictions. In particular, raw_outputs/box_encodings contains the raw bounding box encodings, which need to be decoded using the anchor boxes, as explained here.
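
For context, the decoding that the post-processing op would otherwise perform looks roughly like this (a minimal sketch, assuming the API's default box coder scale factors of 10.0, 10.0, 5.0, 5.0 and anchors expressed as (ycenter, xcenter, height, width); the function name is just illustrative):

    import numpy as np

    def decode_box_encodings(box_encodings, anchors,
                             scale_factors=(10.0, 10.0, 5.0, 5.0)):
        # box_encodings: [num_anchors, 4] raw (ty, tx, th, tw) predictions
        # anchors:       [num_anchors, 4] as (ycenter, xcenter, height, width)
        ty, tx, th, tw = np.split(box_encodings, 4, axis=-1)
        ya, xa, ha, wa = np.split(anchors, 4, axis=-1)

        ycenter = ty / scale_factors[0] * ha + ya
        xcenter = tx / scale_factors[1] * wa + xa
        h = np.exp(th / scale_factors[2]) * ha
        w = np.exp(tw / scale_factors[3]) * wa

        # Return boxes as (ymin, xmin, ymax, xmax).
        return np.concatenate(
            [ycenter - h / 2., xcenter - w / 2.,
             ycenter + h / 2., xcenter + w / 2.], axis=-1)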

Is there a way to export this model, maintaining the bounding box decoding functionality within the graph?

This code converts your checkpoint file into a TFLite model file with the NMS algorithm included.

    # Convert a TF2 Object Detection API checkpoint into a TFLite-compatible
    # SavedModel, then into a .tflite file.
    import os
    import tensorflow as tf
    from google.protobuf import text_format
    from object_detection import export_tflite_graph_lib_tf2
    from object_detection.protos import pipeline_pb2

    pipelineFilePath = 'pipeline.config'   # path to the training pipeline config
    checkPointFileDir = 'checkpoint/'      # directory containing the checkpoint files
    outputDir = 'exported/'                # where the SavedModel and .tflite end up

    ssd_use_regular_nms = True
    centernet_include_keypoints = False
    keypoint_label_map_path = None
    max_detections = 20

    # Parse the pipeline config that was used for training.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.io.gfile.GFile(pipelineFilePath, 'r') as f:
        text_format.Parse(f.read(), pipeline_config)

    # Export a TFLite-compatible SavedModel (includes the detection post-processing).
    export_tflite_graph_lib_tf2.export_tflite_model(
        pipeline_config, checkPointFileDir, outputDir,
        max_detections, ssd_use_regular_nms,
        centernet_include_keypoints, keypoint_label_map_path)
    print("Created TFLite-compatible graph from checkpoint file")

    # Now build a .tflite model file in outputDir.
    converter = tf.lite.TFLiteConverter.from_saved_model(
        os.path.join(outputDir, 'saved_model'))
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TensorFlow ops where needed
    ]
    tflite_model = converter.convert()

    TFLITE_MODEL_FILE = os.path.join(outputDir, 'model.tflite')
    with open(TFLITE_MODEL_FILE, 'wb') as f:
        f.write(tflite_model)
    print(f"Generated tflite model in {outputDir}")

You can then run inference as follows:

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_FILE)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    _, h, w, _ = input_details[0]['shape']

    # preprocess() is your own helper: load the image at input_image_path,
    # resize it to (h, w), normalize it and add a batch dimension so it
    # matches the model's input tensor.
    input_tensor = preprocess(input_image_path, h, w)
    interpreter.set_tensor(input_details[0]['index'], input_tensor)
    interpreter.invoke()

    # Get results.
    scores = interpreter.get_tensor(output_details[0]['index'])
    boxes = interpreter.get_tensor(output_details[1]['index'])
    num = interpreter.get_tensor(output_details[2]['index'])
    classes = interpreter.get_tensor(output_details[3]['index'])
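
Note that the order of the four detection outputs can differ between exports and TensorFlow versions; if the indices above don't match your model, print the output details and map the tensors by name and shape:

    # Print index, name and shape of each output so you can assign them correctly.
    for d in output_details:
        print(d['index'], d['name'], d['shape'])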
