繁体   English   中英

在 Tensorflow 对象检测中根据类别索引打印标签/类名

[英]Printing label / class name based on category index in Tensorflow Object Detection

我需要帮助。 我想根据检测到的每个对象打印我的标签名称。 但我还是想不通怎么办? 如果有人可以帮助/提供一些指导,我将不胜感激。 谢谢。

我从其他帖子中读到我可以使用“print([category_index.get(i) for i in classes[0]])”,但它给了我一个错误:'numpy.int64'类型的参数不可迭代

IMAGE_SIZE = (24, 16)
IMAGE_PATH = os.path.join(paths['IMAGE_PATH'], 'test', '1.jpg')
img = cv2.imread(IMAGE_PATH)
image_np = np.array(img)

input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
detections = detect_fn(input_tensor)

num_detections = int(detections.pop('num_detections'))
detections = {key: value[0, :num_detections].numpy()
              for key, value in detections.items()}
detections['num_detections'] = num_detections

# detection_classes should be ints.
detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

label_id_offset = 1
image_np_with_detections = image_np.copy()

viz_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_detections,
            detections['detection_boxes'],
            detections['detection_classes']+label_id_offset,
            detections['detection_scores'],
            category_index,
            use_normalized_coordinates=True,
            max_boxes_to_draw=8,
            min_score_thresh=.25,
            agnostic_mode=False)
    
plt.figure(figsize=(IMAGE_SIZE))
plt.imshow(cv2.cvtColor(image_np_with_detections, cv2.COLOR_BGR2RGB)) 

#print ([category_index.get(value) for index, value in enumerate(classes[0]) if scores[0,index] > 0.25])
print(label['name'])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM