简体   繁体   English

如何从 TFLite Object Detection Python 中获取有用的数据

[英]How to get useful data from TFLite Object Detection Python

I have a raspberry pi 4, and I want to do object detection at a good frame rate.我有一个 raspberry pi 4,我想以良好的帧速率进行对象检测。 I tried tensorflow and YOLO but both run at 1 fps.我尝试过 tensorflow 和 YOLO,但都以 1 fps 的速度运行。 So I am trying TensorFlow Lite.所以我正在尝试 TensorFlow Lite。 I have downloaded the tflite file and the labelmap.txt file.我已经下载了 tflite 文件和 labelmap.txt 文件。 I have used this link to try to run inference.我已使用此链接尝试运行推理。 Here I faced a problem.在这里,我遇到了一个问题。 I do not understand how to get the results (classification, coor for bounding box and conf) out of the output.我不明白如何从输出中获得结果(分类、边界框坐标和 conf)。

Here is my code:这是我的代码:

import tensorflow as tf 
import numpy as np
import cv2

interpreter = tf.lite.Interpreter(model_path="/content/drive/My Drive/detect.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
print()

input_shape = input_details[0]['shape']
im = cv2.imread("/content/drive/My Drive/doggy.jpg")
im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_rgb = cv2.resize(im_rgb, (input_shape[1], input_shape[2]))
input_data = np.expand_dims(im_rgb, axis=0)
print(input_data.shape)
print()

interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
print()
print(output_data)

And here is my output:这是我的输出:

[{'name': 'normalized_input_image_tensor', 'index': 175, 'shape': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
[{'name': 'TFLite_Detection_PostProcess', 'index': 167, 'shape': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 168, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 169, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 170, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

(1, 300, 300, 3)

(1, 10, 4)

[[[ 1.66415479e-02  5.48024022e-04  8.67791831e-01  3.35325867e-01]
  [ 7.41335377e-02  3.22245747e-01  9.64617252e-01  9.71388936e-01]
  [-2.11861148e-03  5.41743517e-01  2.60241032e-01  7.02846169e-01]
  [-5.67546487e-03  3.26282382e-01  8.59034657e-01  6.30770981e-01]
  [ 7.27111334e-03  7.90268779e-01  2.86753297e-01  9.56545353e-01]
  [ 2.07318692e-03  7.96441555e-01  5.48386931e-01  9.96111989e-01]
  [-1.04907183e-02  2.38761827e-01  6.75976276e-01  7.01156497e-01]
  [ 3.12007014e-02  1.34294275e-02  5.82291842e-01  3.10949832e-01]
  [-1.95578858e-03  7.05318868e-01  9.18281525e-02  7.96184599e-01]
  [-5.43205580e-03  3.23292404e-01  6.34427786e-01  5.68508685e-01]]]

The output (the last list) seems to be an array of very small numbers, how do I get the result out of this?输出(最后一个列表)似乎是一个非常小的数字数组,我如何从中得到结果?

Thanks谢谢

I solved this issue with the help of @daverim at github, where I had opened an issue.我在 github 上的 @daverim 的帮助下解决了这个问题,我在那里打开了一个问题。 https://github.com/tensorflow/tensorflow/issues/34761 . https://github.com/tensorflow/tensorflow/issues/34761 Here is the code to get useful data:这是获取有用数据的代码:

detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])
print(num_boxes)
for i in range(int(num_boxes[0])):
  if detection_scores[0, i] > .5:
       class_id = detection_classes[0, i]
       print(class_id)

Using the labelmap.txt file we can get the class name.使用 labelmap.txt 文件,我们可以获得类名。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 如何从 facebook 页面中的 xhr 响应中获取有用的数据? - how to get useful data from xhr responses in facebook page? 如何从列表中的列表中获取有用的数据? - How can I get useful data from lists in a list? 如何从pyspeedtest结果中获取有用的数据? - How do I get useful data from pyspeedtest results? 使用用于 object 检测的元数据创建 tflite model - create tflite model with metadata for object detection Tensorflow对象检测-将.pb文件转换为tflite - Tensorflow Object Detection - Convert .pb file to tflite 如何从量化的 TFLite 中获取 class 索引? - How to get class indices from a quantized TFLite? 如何从Python中的对象获取数据 - How to get data from object in Python 将 SSD object 检测 model 转换为 TFLite 并将其从 float 量化为 uint8 用于 EdgeTPU - Converting SSD object detection model to TFLite and quantize it from float to uint8 for EdgeTPU 如何将对象检测模型(在其冻结图形中)转换为.tflite,而不需要任何输入和输出数组的知识 - How to convert an object detection model, in it's frozen graph, to a .tflite, without any knowledge of input and output arrays 如何评估 Python 中的 object 检测? - How to evaluate object detection in Python?
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM