
Issue creating a TFLite model populated with metadata (for object detection)

I am trying to run a TFLite object detection model on Android. To do so:

  1. I successfully trained the model on my own set of images as follows:

(a) Training:

!python3 object_detection/model_main.py \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--model_dir=training/

(after modifying the config file to point to my own TFRecords)

(b) Export inference graph

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--output_directory={output_directory} \
--trained_checkpoint_prefix={last_model_path}

(c) Create a TFLite-ready graph

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_tflite_ssd_graph.py \
  --pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
  --output_directory={output_directory} \
  --trained_checkpoint_prefix={last_model_path} \
  --add_postprocessing_op=true

  2. I then created a TFLite model from the graph file using tflite_convert:

    !tflite_convert \
    --graph_def_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/tflite_graph.pb \
    --output_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/detect3.tflite \
    --output_format=TFLITE \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=FLOAT \
    --allow_custom_ops

I validated this TFLite model independently, and it works fine outside of Android.

I now need to populate the TFLite model with metadata so that it can be used by the sample Android app linked below. Without metadata, running the app from Android Studio fails with the error "not a valid Zip file and does not have associated files".

https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/android/README.md

The sample .tflite model provided at the same link is populated with metadata and works fine.
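The "not a valid Zip file" wording fits how, as far as I know, the metadata populator stores associated files such as labelmap.txt: they are zipped onto the end of the .tflite flatbuffer, and a model produced by tflite_convert alone has no such archive. A minimal stdlib sketch to check whether a model carries packed files; the helper name list_packed_associated_files is hypothetical (the library's MetadataDisplayer.get_packed_associated_file_list() reports the same information):

```python
import io
import zipfile


def list_packed_associated_files(model_buf):
    """Return the names of files zipped onto the end of a .tflite model buffer.

    An empty list means no associated files (e.g. labelmap.txt) are packed,
    as is the case for a model converted with tflite_convert alone.
    """
    try:
        # zipfile locates the archive from the end of the buffer, so the
        # flatbuffer bytes in front of it do not get in the way.
        with zipfile.ZipFile(io.BytesIO(bytes(model_buf))) as zf:
            return zf.namelist()
    except zipfile.BadZipFile:
        return []
```

Usage: `with open("detect3.tflite", "rb") as f: print(list_packed_associated_files(f.read()))`. An empty list for detect3.tflite and a non-empty one for the sample model would match the error seen in Android Studio.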

I tried to add the metadata by following https://www.tensorflow.org/lite/convert/metadata#deep_dive_into_the_image_classification_example:

populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
populator.populate()

The rest of my code is practically the same as in that guide, except that I changed the metadata descriptions from image classification to object detection and specified the location of labelmap.txt. It fails with the following error:

<ipython-input-6-173fc798ea6e> in <module>()
      1 populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
----> 2 populator.load_metadata_buffer(metadata_buf)
      3 populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
      4 populator.populate()

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_lite_support/metadata/metadata.py in _validate_metadata(self, metadata_buf)
    540           "The number of output tensors ({0}) should match the number of "
    541           "output tensor metadata ({1})".format(num_output_tensors,
--> 542                                                 num_output_meta))
    543 
    544 

ValueError: The number of output tensors (4) should match the number of output tensor metadata (1)

I believe the 4 output tensors are the ones listed in --output_arrays in step 2 (please correct me if I'm wrong), but I am not sure how to update the output tensor metadata to match.
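The error is a plain count check: metadata adapted from the image classification example describes one output tensor, while the TFLite_Detection_PostProcess op produces four (boxes, classes, scores, number of detections, in that order). A minimal sketch that mimics the rule; check_output_metadata is a hypothetical helper, not the library's code, and the metadata names are the ones used in the manual answer's snippet:

```python
# Output names passed to --output_arrays in step 2, in order.
OUTPUT_ARRAYS = [
    "TFLite_Detection_PostProcess",    # box locations
    "TFLite_Detection_PostProcess:1",  # class ids
    "TFLite_Detection_PostProcess:2",  # scores
    "TFLite_Detection_PostProcess:3",  # number of detections
]


def check_output_metadata(output_tensor_names, output_metadata_names):
    """Require exactly one metadata entry per output tensor, paired in order."""
    if len(output_tensor_names) != len(output_metadata_names):
        raise ValueError(
            "The number of output tensors ({0}) should match the number of "
            "output tensor metadata ({1})".format(
                len(output_tensor_names), len(output_metadata_names)))
    return list(zip(output_tensor_names, output_metadata_names))
```

Passing a single entry such as ["probability"] reproduces the "(4) ... (1)" message; passing four entries pairs each output with its description.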

Can anyone who has recently used a custom object detection model on Android help? Alternatively, could someone explain how to update the output tensor metadata so that it describes 4 tensors instead of 1?

Update on Jun 10, 2021:

See the latest tutorial on the Metadata Writer library on tensorflow.org.

Update:

The Metadata Writer library has been released. It currently supports image classifiers and object detectors, and support for more tasks is on the way.

Here is an example to write metadata for an object detector model:

  1. Install the TFLite Support nightly PyPI package:
pip install tflite_support_nightly
  2. Write metadata to the model using the following script:
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1_1_default_1.tflite"
_LABEL_FILE = "labelmap.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_1_default_1_metadata.tflite"

writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [127.5], [127.5], [_LABEL_FILE])
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

# Verify the populated metadata and associated files.
displayer = metadata.MetadataDisplayer.with_model_file(_SAVE_TO_PATH)
print("Metadata populated:")
print(displayer.get_metadata_json())
print("Associated file(s) populated:")
print(displayer.get_packed_associated_file_list())
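The two [127.5] arguments to create_for_inference() are the input normalization mean and standard deviation recorded in the metadata; a consumer of the model is expected to compute (value - mean) / std before inference, which maps raw 0-255 pixels to [-1, 1]. A small sketch of that arithmetic, assuming a single mean/std is applied to every channel:

```python
# Values passed to create_for_inference() as input_norm_mean / input_norm_std.
INPUT_NORM_MEAN = 127.5
INPUT_NORM_STD = 127.5


def normalize_pixel(value, mean=INPUT_NORM_MEAN, std=INPUT_NORM_STD):
    """Normalize one raw 0-255 channel value as (value - mean) / std."""
    return (value - mean) / std


def normalize_rgb(pixel):
    """Normalize an (R, G, B) tuple with the same mean/std for each channel."""
    return tuple(normalize_pixel(channel) for channel in pixel)
```

For example, normalize_rgb((0, 255, 0)) gives (-1.0, 1.0, -1.0). If your model was trained with different preprocessing, these two arguments should change accordingly.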

---------- Previous answer that writes metadata manually --------

Here is a code snippet you can use to populate metadata for object detection models, which is compatible with the TFLite Android app.

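# Imports used by this snippet (module names as in the TFLite metadata guide).
import os

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

model_meta = _metadata_fb.ModelMetadataT()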
model_meta.name = "SSD_Detector"
model_meta.description = (
    "Identify which of a known set of objects might be present and provide "
    "information about their positions within the given image or a video "
    "stream.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

# Creates outputs info.
output_location_meta = _metadata_fb.TensorMetadataT()
output_location_meta.name = "location"
output_location_meta.description = "The locations of the detected boxes."
output_location_meta.content = _metadata_fb.ContentT()
output_location_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.BoundingBoxProperties)
output_location_meta.content.contentProperties = (
    _metadata_fb.BoundingBoxPropertiesT())
output_location_meta.content.contentProperties.index = [1, 0, 3, 2]
output_location_meta.content.contentProperties.type = (
    _metadata_fb.BoundingBoxType.BOUNDARIES)
output_location_meta.content.contentProperties.coordinateType = (
    _metadata_fb.CoordinateType.RATIO)
output_location_meta.content.range = _metadata_fb.ValueRangeT()
output_location_meta.content.range.min = 2
output_location_meta.content.range.max = 2

output_class_meta = _metadata_fb.TensorMetadataT()
output_class_meta.name = "category"
output_class_meta.description = "The categories of the detected boxes."
output_class_meta.content = _metadata_fb.ContentT()
output_class_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_class_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_class_meta.content.range = _metadata_fb.ValueRangeT()
output_class_meta.content.range.min = 2
output_class_meta.content.range.max = 2
label_file = _metadata_fb.AssociatedFileT()
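# The name must match the basename of the label file later passed to
# populator.load_associated_files() (labelmap.txt in the question).
label_file.name = os.path.basename("labelmap.txt")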
label_file.description = "Label of objects that this model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
output_class_meta.associatedFiles = [label_file]

output_score_meta = _metadata_fb.TensorMetadataT()
output_score_meta.name = "score"
output_score_meta.description = "The scores of the detected boxes."
output_score_meta.content = _metadata_fb.ContentT()
output_score_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_score_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_score_meta.content.range = _metadata_fb.ValueRangeT()
output_score_meta.content.range.min = 2
output_score_meta.content.range.max = 2

output_number_meta = _metadata_fb.TensorMetadataT()
output_number_meta.name = "number of detections"
output_number_meta.description = "The number of the detected boxes."
output_number_meta.content = _metadata_fb.ContentT()
output_number_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_number_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())

# Creates subgraph info.
group = _metadata_fb.TensorGroupT()
group.name = "detection result"
group.tensorNames = [
    output_location_meta.name, output_class_meta.name,
    output_score_meta.name
]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [
    output_location_meta, output_class_meta, output_score_meta,
    output_number_meta
]
subgraph.outputTensorGroups = [group]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
# Pass this buffer to populator.load_metadata_buffer(), as in the question's code.
metadata_buf = b.Output()
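The location metadata above declares BOUNDARIES boxes in RATIO coordinates with index = [1, 0, 3, 2], which tells a reader where {left, top, right, bottom} sit inside each box; the SSD post-processing op emits [top, left, bottom, right]. A sketch of how the four outputs fit together, assuming the batch dimension has already been dropped, that class ids index the label list directly (check your label file for a placeholder first line), and a hypothetical score threshold of 0.5:

```python
# Position of {left, top, right, bottom} inside each box, as declared by
# output_location_meta.content.contentProperties.index.
BOX_INDEX = [1, 0, 3, 2]


def decode_detections(boxes, classes, scores, num_detections, labels,
                      image_width, image_height, min_score=0.5):
    """Turn the four post-processed outputs into (label, score, pixel box)."""
    results = []
    for i in range(int(num_detections)):
        if scores[i] < min_score:
            continue
        # Reorder the raw box into left, top, right, bottom ratios.
        left, top, right, bottom = (boxes[i][k] for k in BOX_INDEX)
        # RATIO coordinates are fractions of the image size.
        pixel_box = (left * image_width, top * image_height,
                     right * image_width, bottom * image_height)
        results.append((labels[int(classes[i])], scores[i], pixel_box))
    return results
```

The "number of detections" output bounds the loop, which is why it gets its own metadata entry even though it is not part of the "detection result" tensor group.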
