
Issue in creating a TFLite model populated with metadata (for object detection)

I am trying to run a TFLite model on Android for object detection. To do so:

  1. I successfully trained the model on my own image set, as follows:

(a) Training:

!python3 object_detection/model_main.py \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--model_dir=training/

(I modified the config file to point to the locations of my own TFRecords.)

(b) Export the inference graph:

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--output_directory={output_directory} \
--trained_checkpoint_prefix={last_model_path}

(c) Create a TFLite-ready graph:

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_tflite_ssd_graph.py \
  --pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
  --output_directory={output_directory} \
  --trained_checkpoint_prefix={last_model_path} \
  --add_postprocessing_op=true
  2. I created a TFLite model from the graph file with tflite_convert, as follows:

    !tflite_convert \
    --graph_def_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/tflite_graph.pb \
    --output_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/detect3.tflite \
    --output_format=TFLITE \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=FLOAT \
    --allow_custom_ops
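Because the Drive paths contain spaces, the backslash escaping in the command above is easy to get wrong. The same TF 1.x `tflite_convert` call can instead be assembled as an argument list, where each path stays a single element and needs no escaping. This is a minimal sketch: `tflite_convert_args` and `as_shell_command` are hypothetical helper names, the flags are copied from the command above, and actually running it assumes a TF 1.x install whose `tflite_convert` accepts `--graph_def_file`.

```python
import shlex
from typing import List

# The four outputs added by export_tflite_ssd_graph.py with --add_postprocessing_op=true.
DETECTION_OUTPUT_ARRAYS = [
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
]


def tflite_convert_args(graph_def_file: str, output_file: str) -> List[str]:
    """Build the argv for the conversion step; each path stays one element, spaces included."""
    return [
        "tflite_convert",
        "--graph_def_file=" + graph_def_file,
        "--output_file=" + output_file,
        "--output_format=TFLITE",
        "--input_shapes=1,300,300,3",
        "--input_arrays=normalized_input_image_tensor",
        "--output_arrays=" + ",".join(DETECTION_OUTPUT_ARRAYS),
        "--inference_type=FLOAT",
        "--allow_custom_ops",
    ]


def as_shell_command(args: List[str]) -> str:
    """Quote every argument so the command can be pasted into a shell or a `!` cell."""
    return " ".join(shlex.quote(a) for a in args)
```

With the list in hand, `subprocess.run(tflite_convert_args(pb_path, tflite_path), check=True)` would run the conversion without going through a shell at all.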

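The TFLite model above was validated independently and works well (outside of Android).

I now need to populate the TFLite model with metadata so that it can be used by the sample Android code provided at the link below (otherwise, when running from Android Studio, I get an error saying the model is not a valid Zip file and has no associated files).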

https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/android/README.md

The sample .tflite model provided at the same link is populated with metadata and works fine.
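The "not a valid Zip file" error is consistent with how metadata is packaged: according to the TFLite metadata documentation, once associated files such as the label map are packed into a model, the .tflite file can also be opened with ordinary zip tools. A minimal standard-library sketch for checking whether a model has any packed files (`packed_associated_files` is a hypothetical helper name):

```python
import io
import zipfile
from typing import List


def packed_associated_files(model_bytes: bytes) -> List[str]:
    """Return the names of files packed into a .tflite model, or [] if it is not readable as a zip."""
    buf = io.BytesIO(model_bytes)
    if not zipfile.is_zipfile(buf):
        return []
    with zipfile.ZipFile(buf) as zf:
        return zf.namelist()
```

Usage: `with open("detect3.tflite", "rb") as f: print(packed_associated_files(f.read()))`. A model straight out of `tflite_convert` would be expected to give an empty list, while a model with a packed label file should list it.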

When I try to follow the approach from this link: https://www.tensorflow.org/lite/convert/metadata#deep_dive_into_the_image_classification_example

populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
populator.populate()

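to add the metadata (the rest of the code is essentially the same, except that I changed the metadata description to object detection instead of image classification and specified the location of labelmap.txt), I get the following error: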

<ipython-input-6-173fc798ea6e> in <module>()
  1 populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
  ----> 2 populator.load_metadata_buffer(metadata_buf)
        3 populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
        4 populator.populate()

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_lite_support/metadata/metadata.py in _validate_metadata(self, metadata_buf)
    540           "The number of output tensors ({0}) should match the number of "
    541           "output tensor metadata ({1})".format(num_output_tensors,
--> 542                                                 num_output_meta))
    543 
    544 

ValueError: The number of output tensors (4) should match the number of output tensor metadata (1)
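The check behind this error only compares two counts: the output tensors in the model's subgraph and the entries in `outputTensorMetadata`. A pure-Python sketch of the same comparison (the function name is hypothetical; only the message text is taken from the traceback) shows why one output entry fails against the four detection outputs while four entries pass:

```python
from typing import Sequence

# Output tensor names of the converted SSD model (from --output_arrays in step 2).
DETECTION_OUTPUTS = [
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
]


def check_output_metadata_count(
    output_tensor_names: Sequence[str], output_tensor_metadata: Sequence[object]
) -> None:
    """Raise ValueError when the metadata does not describe every output tensor."""
    num_output_tensors = len(output_tensor_names)
    num_output_meta = len(output_tensor_metadata)
    if num_output_tensors != num_output_meta:
        raise ValueError(
            "The number of output tensors ({0}) should match the number of "
            "output tensor metadata ({1})".format(num_output_tensors, num_output_meta))
```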

The 4 output tensors are the ones listed in output_arrays in step 2 (someone may correct me there). I am not sure how to update the output tensor metadata accordingly.
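For the SSD export used here, the four outputs of `TFLite_Detection_PostProcess` are, in order: box locations as [ymin, xmin, ymax, xmax] normalized to 0..1, class indices, scores, and the number of valid detections. The same order appears in the manual metadata further down (location, category, score, number of detections). A minimal sketch of how a client could combine them, assuming the batch dimension has already been stripped; `Detection`, `decode_detections` and the 0.5 threshold are hypothetical choices:

```python
from typing import List, NamedTuple, Sequence


class Detection(NamedTuple):
    box: List[float]  # [ymin, xmin, ymax, xmax], normalized to 0..1
    class_index: int  # class index as reported by the model
    score: float


def decode_detections(
    boxes: Sequence[Sequence[float]],  # output 0
    classes: Sequence[float],  # output 1
    scores: Sequence[float],  # output 2
    num_detections: float,  # output 3, a single value
    min_score: float = 0.5,
) -> List[Detection]:
    """Keep the first `num_detections` results whose score reaches `min_score`."""
    results = []
    for i in range(int(num_detections)):
        if scores[i] >= min_score:
            results.append(Detection(list(boxes[i]), int(classes[i]), float(scores[i])))
    return results
```

Only the first `num_detections` rows are meaningful; the remaining rows of the fixed-size outputs are padding.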

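Can anyone who has recently done object detection with a custom model (and then deployed it on Android) help? Or help me understand how to update the tensor metadata to describe 4 outputs instead of 1?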

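Update, June 10, 2021:

See the latest tutorial on the Metadata Writer library on tensorflow.org.

Update:

The Metadata Writer library has been released. It currently supports image classifiers and object detectors, and support for more tasks is in progress.

Here is an example of writing metadata for an object detector model:

  1. Install the TFLite Support nightly PyPI package: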
pip install tflite_support_nightly
  2. Use the following script to write metadata into the model:
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
from tflite_support import metadata

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "ssd_mobilenet_v1_1_default_1.tflite"
_LABEL_FILE = "labelmap.txt"
_SAVE_TO_PATH = "ssd_mobilenet_v1_1_default_1_metadata.tflite"

writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [127.5], [127.5], [_LABEL_FILE])
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

# Verify the populated metadata and associated files.
displayer = metadata.MetadataDisplayer.with_model_file(_SAVE_TO_PATH)
print("Metadata populated:")
print(displayer.get_metadata_json())
print("Associated file(s) populated:")
print(displayer.get_packed_associated_file_list())
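The two `[127.5]` arguments are the input normalization mean and standard deviation stored in the metadata (the same values the manual version below sets in `NormalizationOptions`). For a float model, a client applies `(pixel - mean) / std`, which maps 0..255 onto [-1, 1], the range the SSD MobileNet v2 feature extractor works in. A minimal sketch of that arithmetic; `normalize_pixels` is a hypothetical name and not part of tflite_support:

```python
from typing import List, Sequence


def normalize_pixels(
    pixels: Sequence[float], mean: float = 127.5, std: float = 127.5
) -> List[float]:
    """Apply the (value - mean) / std normalization recorded in the model metadata."""
    return [(p - mean) / std for p in pixels]
```

For example, `normalize_pixels([0, 127.5, 255])` gives `[-1.0, 0.0, 1.0]`.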

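---------- Previous answer: writing the metadata manually --------

Here is a code snippet you can use to populate the metadata of an object detection model so that it is compatible with the TFLite Android app.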

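import os

import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

model_meta = _metadata_fb.ModelMetadataT()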
model_meta.name = "SSD_Detector"
model_meta.description = (
    "Identify which of a known set of objects might be present and provide "
    "information about their positions within the given image or a video "
    "stream.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

# Creates outputs info.
output_location_meta = _metadata_fb.TensorMetadataT()
output_location_meta.name = "location"
output_location_meta.description = "The locations of the detected boxes."
output_location_meta.content = _metadata_fb.ContentT()
output_location_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.BoundingBoxProperties)
output_location_meta.content.contentProperties = (
    _metadata_fb.BoundingBoxPropertiesT())
output_location_meta.content.contentProperties.index = [1, 0, 3, 2]
output_location_meta.content.contentProperties.type = (
    _metadata_fb.BoundingBoxType.BOUNDARIES)
output_location_meta.content.contentProperties.coordinateType = (
    _metadata_fb.CoordinateType.RATIO)
output_location_meta.content.range = _metadata_fb.ValueRangeT()
output_location_meta.content.range.min = 2
output_location_meta.content.range.max = 2

output_class_meta = _metadata_fb.TensorMetadataT()
output_class_meta.name = "category"
output_class_meta.description = "The categories of the detected boxes."
output_class_meta.content = _metadata_fb.ContentT()
output_class_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_class_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_class_meta.content.range = _metadata_fb.ValueRangeT()
output_class_meta.content.range.min = 2
output_class_meta.content.range.max = 2
label_file = _metadata_fb.AssociatedFileT()
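# Must match the basename of the file passed to load_associated_files() (labelmap.txt in the question).
label_file.name = os.path.basename("label.txt")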
label_file.description = "Label of objects that this model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
output_class_meta.associatedFiles = [label_file]

output_score_meta = _metadata_fb.TensorMetadataT()
output_score_meta.name = "score"
output_score_meta.description = "The scores of the detected boxes."
output_score_meta.content = _metadata_fb.ContentT()
output_score_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_score_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_score_meta.content.range = _metadata_fb.ValueRangeT()
output_score_meta.content.range.min = 2
output_score_meta.content.range.max = 2

output_number_meta = _metadata_fb.TensorMetadataT()
output_number_meta.name = "number of detections"
output_number_meta.description = "The number of the detected boxes."
output_number_meta.content = _metadata_fb.ContentT()
output_number_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_number_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())

# Creates subgraph info.
group = _metadata_fb.TensorGroupT()
group.name = "detection result"
group.tensorNames = [
    output_location_meta.name, output_class_meta.name,
    output_score_meta.name
]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [
    output_location_meta, output_class_meta, output_score_meta,
    output_number_meta
]
subgraph.outputTensorGroups = [group]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()  # pass this to populator.load_metadata_buffer()
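The bounding-box settings above (`index = [1, 0, 3, 2]`, `BOUNDARIES`, `RATIO`) tell the consumer how to read each box. As I read the metadata schema, `index[k]` is the position in the tensor of the k-th element in the default order {left, top, right, bottom}, so the SSD layout [ymin, xmin, ymax, xmax] maps to [1, 0, 3, 2], and `RATIO` means the values are fractions of the image size. A minimal sketch of that mapping (`to_pixel_rect` is a hypothetical helper):

```python
from typing import Sequence, Tuple


def to_pixel_rect(
    box: Sequence[float],
    width: int,
    height: int,
    index: Sequence[int] = (1, 0, 3, 2),
) -> Tuple[float, float, float, float]:
    """Return (left, top, right, bottom) in pixels for one RATIO-encoded box."""
    left = box[index[0]] * width
    top = box[index[1]] * height
    right = box[index[2]] * width
    bottom = box[index[3]] * height
    return left, top, right, bottom
```

For example, the SSD box `[0.25, 0.5, 0.75, 1.0]` on a 400x200 image becomes `(200.0, 50.0, 400.0, 150.0)`.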
